diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..513e66581496adc0ee5d90eddc42cabe83436e2b
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,57 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o filter=lfs diff=lfs merge=lfs -text
+llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o filter=lfs diff=lfs merge=lfs -text
+llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o filter=lfs diff=lfs merge=lfs -text
+static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
+static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
+static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
+static/audio/audio2.wav filter=lfs diff=lfs merge=lfs -text
+static/chat/audio1.mp3 filter=lfs diff=lfs merge=lfs -text
+static/chat/audio2.mp3 filter=lfs diff=lfs merge=lfs -text
+static/emergent/audio1.wav filter=lfs diff=lfs merge=lfs -text
+static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
+static/speech/339a1acd-afcb-466b-a7b1-8661e59b1e56.wav filter=lfs diff=lfs merge=lfs -text
+static/speech/audio3.wav filter=lfs diff=lfs merge=lfs -text
+static/speech/bcc6057d-0dda-435d-b956-a96ab27bc9e4.wav filter=lfs diff=lfs merge=lfs -text
+static/speech/be84d293-5e9c-4158-9a1e-b4dd1acb7d70.wav filter=lfs diff=lfs merge=lfs -text
+static/speech/fec3402e-7883-45c0-90d4-38647f615dc3.wav filter=lfs diff=lfs merge=lfs -text
+static/think/audio1.wav filter=lfs diff=lfs merge=lfs -text
+static/think/audio2.wav filter=lfs diff=lfs merge=lfs -text
+static/voice/voice_2.mp3 filter=lfs diff=lfs merge=lfs -text
+static/speech/speaker1.flac filter=lfs diff=lfs merge=lfs -text
+static/speech/videoplayback.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5007c25dfd9c5976effddde197cfe2089ef8d035
--- /dev/null
+++ b/README.md
@@ -0,0 +1,147 @@
+---
+license: other
+title: Audio Flamingo 3 Demo
+sdk: gradio
+emoji: ๐
+colorFrom: green
+colorTo: green
+pinned: true
+short_description: Audio Flamingo 3 Demo
+---
+
+# Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models
+
+## Overview
+
+This repo contains the PyTorch implementation of [Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models](). Audio Flamingo 3 (AF3) is a fully open, state-of-the-art Large Audio-Language Model (LALM) that advances reasoning and understanding across speech, sounds, and music. AF3 builds on previous work with innovations in:
+
+- Unified audio representation learning (speech, sound, music)
+- Flexible, on-demand chain-of-thought reasoning (Thinking in Audio)
+- Long-context audio comprehension (including speech) up to 10 minutes
+- Multi-turn, multi-audio conversational dialogue (AF3-Chat)
+- Voice-to-voice interaction (AF3-Chat)
+
+Extensive evaluations confirm AF3's effectiveness, setting new benchmarks on over 20 public audio understanding and reasoning tasks.
+
+
+## Main Results
+
+Audio Flamingo 3 outperforms prior SOTA models including GAMA, Audio Flamingo, Audio Flamingo 2, Qwen-Audio, Qwen2-Audio, Qwen2.5-Omni, LTU, LTU-AS, SALMONN, AudioGPT, Gemini Flash v2 and Gemini Pro v1.5 on a number of understanding and reasoning benchmarks.
+
+
+
+
+
+
+
+
+
+## Audio Flamingo 3 Architecture
+
+Audio Flamingo 3 uses the AF-Whisper unified audio encoder, an MLP-based audio adaptor, a decoder-only LLM backbone (Qwen2.5-7B), and a streaming TTS module (AF3-Chat).
+Audio Flamingo 3 can take audio inputs up to 10 minutes long.
+
+
+
+
+
+## Installation
+
+```bash
+./environment_setup.sh af3
+```
+
+## Code Structure
+
+- The folder ```audio_flamingo_3/``` contains the main training and inference code of Audio Flamingo 3.
+- The folder ```audio_flamingo_3/scripts``` contains the inference scripts of Audio Flamingo 3 in case you would like to use our pretrained checkpoints on HuggingFace.
+
+Each folder is self-contained and we expect no cross dependencies between these folders. This repo does not contain the code for the Streaming-TTS pipeline, which will be released in the near future.
+
+## Single Line Inference
+
+To run inference with the stage-3 model directly, run the command below:
+```bash
+python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav
+```
+
+To run inference with the stage-3.5 model, run the command below:
+```bash
+python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --model-path /path/to/checkpoint/af3-7b/stage35 --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav --peft-mode
+```
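+
+You can also call the model from Python instead of the CLI. The snippet below is a minimal sketch of the same flow used by `app.py` and `llava/cli/infer_audio.py`; the checkpoint paths and the example audio file are placeholders:
+
+```python
+import torch
+import llava
+from peft import PeftModel
+
+# Load the stage-3 base model (placeholder path).
+model = llava.load("/path/to/checkpoint/af3-7b")
+
+# Optionally attach the stage-3.5 PEFT adapter for on-demand thinking.
+model = PeftModel.from_pretrained(
+    model,
+    "/path/to/checkpoint/af3-7b/stage35",
+    device_map="auto",
+    torch_dtype=torch.float16,
+)
+
+# Wrap the audio file and generate a response.
+sound = llava.Sound("static/audio/audio2.wav")
+response = model.generate_content([sound, "Please describe the audio in detail."])
+print(response)
+```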
+
+## References
+
+The main training and inference code within each folder is adapted from [NVILA](https://github.com/NVlabs/VILA/tree/main) under the [Apache license](incl_licenses/License_1.md).
+
+## License
+
+- The code in this repo is under [MIT license](incl_licenses/MIT_license.md).
+- The checkpoints are for non-commercial use only, under the [NVIDIA OneWay Noncommercial License](incl_licenses/NVIDIA_OneWay_Noncommercial_License.docx). They are also subject to the [Qwen Research license](https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/LICENSE), the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the original licenses accompanying each training dataset.
+- Notice: Audio Flamingo 3 is built with Qwen-2.5. Qwen is licensed under the Qwen RESEARCH LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
+
+
+## Citation
+
+- Audio Flamingo 2
+```
+@article{ghosh2025audio,
+ title={Audio Flamingo 2: An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities},
+ author={Ghosh, Sreyan and Kong, Zhifeng and Kumar, Sonal and Sakshi, S and Kim, Jaehyeon and Ping, Wei and Valle, Rafael and Manocha, Dinesh and Catanzaro, Bryan},
+ journal={arXiv preprint arXiv:2503.03983},
+ year={2025}
+}
+```
+
+- Audio Flamingo
+```
+@inproceedings{kong2024audio,
+ title={Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities},
+ author={Kong, Zhifeng and Goel, Arushi and Badlani, Rohan and Ping, Wei and Valle, Rafael and Catanzaro, Bryan},
+ booktitle={International Conference on Machine Learning},
+ pages={25125--25148},
+ year={2024},
+ organization={PMLR}
+}
+```
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed35e40c41d5cbcc5e0085abb7623f9dc409f7d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,306 @@
+import gradio as gr
+import torch
+import llava
+from peft import PeftModel
+import os
+from huggingface_hub import snapshot_download
+import copy
+
+# ---------------------------------
+# SINGLE-TURN MODEL SETUP
+# ---------------------------------
+
+MODEL_BASE_SINGLE = snapshot_download(repo_id="nvidia/audio-flamingo-3")
+MODEL_BASE_THINK = os.path.join(MODEL_BASE_SINGLE, 'stage35')
+
+model_single = llava.load(MODEL_BASE_SINGLE, model_base=None)
+model_single_copy = copy.deepcopy(model_single)
+
+# Move the model to GPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model_single = model_single.to_empty(device=device)
+
+generation_config_single = model_single.default_generation_config
+
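+# The 'stage35' directory holds a PEFT adapter that enables the on-demand 'thinking' mode used by the Think / Long tab below.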
+model_think = PeftModel.from_pretrained(
+ model_single,
+ MODEL_BASE_THINK,
+ device_map="auto",
+ torch_dtype=torch.float16,
+)
+model_think.to(device)
+
+# # ---------------------------------
+# # MULTI-TURN MODEL SETUP
+# # ---------------------------------
+# MODEL_BASE_MULTI = snapshot_download(repo_id="nvidia/audio-flamingo-3-chat")
+# model_multi = llava.load(MODEL_BASE_MULTI, model_base=None, devices=[0])
+# generation_config_multi = model_multi.default_generation_config
+
+
+# ---------------------------------
+# SINGLE-TURN INFERENCE FUNCTION
+# ---------------------------------
+def single_turn_infer(audio_file, prompt_text):
+ try:
+ sound = llava.Sound(audio_file)
+ full_prompt = f"<sound>\n{prompt_text}"
+ response = model_single_copy.generate_content([sound, full_prompt], generation_config=generation_config_single)
+ return response
+ except Exception as e:
+ return f"❌ Error: {str(e)}"
+
+# ---------------------------------
+# MULTI-TURN INFERENCE FUNCTION
+# ---------------------------------
+# def multi_turn_chat(user_input, audio_file, history, current_audio):
+# try:
+# if audio_file is not None:
+# current_audio = audio_file # Update state if a new file is uploaded
+
+# if current_audio is None:
+# return history + [("System", "❌ Please upload an audio file before chatting.")], history, current_audio
+
+# sound = llava.Sound(current_audio)
+# prompt = f"\n{user_input}"
+
+# response = model_multi.generate_content([sound, prompt], generation_config=generation_config_multi)
+
+# history.append((user_input, response))
+# return history, history, current_audio
+# except Exception as e:
+# history.append((user_input, f"❌ Error: {str(e)}"))
+# return history, history, current_audio
+
+# ---------------------------------
+# THINK (STAGE 3.5) INFERENCE FUNCTION
+# ---------------------------------
+def think_infer(audio_file, prompt_text):
+ try:
+ sound = llava.Sound(audio_file)
+ full_prompt = f"<sound>\n{prompt_text}"
+ response = model_think.generate_content([sound, full_prompt], generation_config=generation_config_single)
+ return response
+ except Exception as e:
+ return f"❌ Error: {str(e)}"
+
+# ---------------------------------
+# INTERFACE
+# ---------------------------------
+with gr.Blocks(css="""
+.gradio-container {
+ max-width: 100% !important;
+ width: 100% !important;
+ margin: 0 !important;
+ padding: 0 !important;
+}
+#component-0, .gr-block.gr-box {
+ width: 100% !important;
+}
+.gr-block.gr-box, .gr-column, .gr-row {
+ padding: 0 !important;
+ margin: 0 !important;
+}
+""") as demo:
+
+ with gr.Column():
+ gr.HTML("""
+<h1 align="center">Audio Flamingo 3</h1>
+<h3 align="center">Advancing Audio Intelligence with Fully Open Large Audio-Language Models</h3>
+""")
+ # gr.Markdown("#### NVIDIA (2025)")
+
+ with gr.Tabs():
+ # ---------------- SINGLE-TURN ----------------
+ with gr.Tab("🎯 Single-Turn Inference"):
+ with gr.Row():
+ with gr.Column():
+ audio_input_single = gr.Audio(type="filepath", label="Upload Audio")
+ prompt_input_single = gr.Textbox(label="Prompt", placeholder="Ask a question about the audio...", lines=8)
+ btn_single = gr.Button("Generate Answer")
+
+ gr.Examples(
+ examples=[
+ ["static/emergent/audio1.wav", "What is surprising about the relationship between the barking and the music?"],
+ ["static/audio/audio2.wav", "Please describe the audio in detail."],
+ ["static/speech/audio3.wav", "Transcribe any speech you hear."],
+ ],
+ inputs=[audio_input_single, prompt_input_single],
+ label="🧪 Try Examples"
+ )
+
+ with gr.Column():
+ output_single = gr.Textbox(label="Model Response", lines=15)
+
+ btn_single.click(fn=single_turn_infer, inputs=[audio_input_single, prompt_input_single], outputs=output_single)
+ with gr.Tab("🤔 Think / Long"):
+
+ with gr.Row():
+ with gr.Column():
+ audio_input_think = gr.Audio(type="filepath", label="Upload Audio")
+ prompt_input_think = gr.Textbox(label="Prompt", placeholder="To enable thinking, please add the text: '\nPlease think and reason about the input music before you respond.' to your prompt.", lines=8)
+ btn_think = gr.Button("Generate Answer")
+
+ gr.Examples(
+ examples=[
+ ["static/think/audio1.wav", "What are the two people doing in the audio Choose the correct option from the following options:\n(A) One person is demonstrating how to use the equipment\n(B) The two people are discussing how to use the equipment\n(C) The two people are disassembling the equipment\n(D) One person is teaching another person how to use a piece of equipment\n"],
+ ["static/think/audio2.wav", "Is the boat in the video moving closer or further away? Choose the correct option from the following options:\n(A) Closer\n(B) Further\n"],
+ ["static/speech/videoplayback.wav", "Generate a detailed caption for the input audio, describing all notable speech, sound, and musical events comprehensively. In the caption, transcribe all spoken content by all speakers in the audio precisely."],
+ ["static/speech/speaker1.flac", "Transcribe any input speech in the input audio."],
+ ],
+ inputs=[audio_input_think, prompt_input_think],
+ label="🧪 Try Examples"
+ )
+
+ with gr.Column():
+ output_think = gr.Textbox(label="Model Response", lines=30)
+
+ btn_think.click(fn=think_infer, inputs=[audio_input_think, prompt_input_think], outputs=output_think)
+ # ---------------- MULTI-TURN CHAT ----------------
+ with gr.Tab("💬 Multi-Turn Chat"):
+ # chatbot = gr.Chatbot(label="Audio Chatbot")
+ # audio_input_multi = gr.Audio(type="filepath", label="Upload or Replace Audio Context")
+ # user_input_multi = gr.Textbox(label="Your message", placeholder="Ask a question about the audio...", lines=8)
+ # btn_multi = gr.Button("Send")
+ # history_state = gr.State([]) # Chat history
+ # current_audio_state = gr.State(None) # Most recent audio file path
+
+ # btn_multi.click(
+ # fn=multi_turn_chat,
+ # inputs=[user_input_multi, audio_input_multi, history_state, current_audio_state],
+ # outputs=[chatbot, history_state, current_audio_state]
+ # )
+ # gr.Examples(
+ # examples=[
+ # ["static/chat/audio1.mp3", "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?"],
+ # ["static/chat/audio2.mp3", "Switching gears, this one is super energetic and synthetic. If I wanted to remix the calming folk piece into something closer to this, what would you suggest?"],
+ # ],
+ # inputs=[audio_input_multi, user_input_multi],
+ # label="🧪 Try Examples"
+ # )
+ # Add the link to another Gradio demo here
+ gr.Markdown("[Check out our other Gradio demo here](https://huggingface.co/spaces/nvidia/audio-flamingo-3-chat)")
+
+ with gr.Tab("🗣️ Speech Prompt"):
+ # gr.Markdown("Use your **voice** to talk to the model.")
+
+ # with gr.Row():
+ # with gr.Column():
+ # speech_input = gr.Audio(type="filepath", label="Speak or Upload Audio")
+ # btn_speech = gr.Button("Submit")
+ # gr.Examples(
+ # examples=[
+ # ["static/voice/voice_0.mp3"],
+ # ["static/voice/voice_1.mp3"],
+ # ["static/voice/voice_2.mp3"],
+ # ],
+ # inputs=speech_input,
+ # label="🧪 Try Examples"
+ # )
+ # with gr.Column():
+ # response_box = gr.Textbox(label="Model Response", lines=15)
+
+ # btn_speech.click(fn=speech_prompt_infer, inputs=speech_input, outputs=response_box)
+ # Add the link to another Gradio demo here
+ gr.Markdown("[Check out our other Gradio demo here](https://huggingface.co/spaces/nvidia/audio-flamingo-3-chat)")
+
+ # ---------------- ABOUT ----------------
+ with gr.Tab("About"):
+ gr.Markdown("""
+### Overview
+
+**Audio Flamingo 3** is a fully open state-of-the-art (SOTA) large audio-language model that advances reasoning and understanding across speech, sound, and music. AF3 introduces:
+
+(i) AF-Whisper, a unified audio encoder trained using a novel strategy for joint representation learning across all 3 modalities of speech, sound, and music;
+
+(ii) flexible, on-demand thinking, allowing the model to do chain-of-thought reasoning before answering;
+
+(iii) multi-turn, multi-audio chat;
+
+(iv) long audio understanding and reasoning (including speech) up to 10 minutes; and
+
+(v) voice-to-voice interaction.
+
+To enable these capabilities, we propose several large-scale training datasets curated using novel strategies, including AudioSkills-XL, LongAudio-XL, AF-Think, and AF-Chat, and train AF3 with a novel five-stage curriculum-based training strategy. Trained on only open-source audio data, AF3 achieves new SOTA results on over 20 (long) audio understanding and reasoning benchmarks, surpassing both open-weight and closed-source models trained on much larger datasets.
+
+**Key Features:**
+
+💡 Audio Flamingo 3 has strong audio, music and speech understanding capabilities.
+
+💡 Audio Flamingo 3 supports on-demand thinking for chain-of-thought reasoning.
+
+💡 Audio Flamingo 3 supports long audio and speech understanding for audios up to 10 minutes.
+
+💡 Audio Flamingo 3 can hold multi-turn, multi-audio chat with users in complex contexts.
+
+💡 Audio Flamingo 3 has voice-to-voice conversation abilities.
+
+
+""")
+
+ gr.Markdown("© 2025 NVIDIA | Built with ❤️ using Gradio + PyTorch")
+
+
+# -----------------------
+# Launch App
+# -----------------------
+if __name__ == "__main__":
+ demo.launch(share=True)
diff --git a/llava/__init__.py b/llava/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c14ac15313d1fceee30d9a3ef2d69bbba5702722
--- /dev/null
+++ b/llava/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .entry import *
+from .media import *
diff --git a/llava/cli/infer_audio.py b/llava/cli/infer_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..8041b2ab9aa52fc166c87ce4527ad5a8356857da
--- /dev/null
+++ b/llava/cli/infer_audio.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import importlib.util
+import json
+import os
+
+from pydantic import BaseModel
+from termcolor import colored
+
+import llava
+from llava import conversation as clib
+from llava.media import Image, Video, Sound
+from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat
+from peft import PeftModel
+import torch
+
+def get_schema_from_python_path(path: str) -> str:
+ schema_path = os.path.abspath(path)
+ spec = importlib.util.spec_from_file_location("schema_module", schema_path)
+ schema_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(schema_module)
+
+ # Get the Main class from the loaded module
+ Main = schema_module.Main
+ assert issubclass(
+ Main, BaseModel
+ ), f"The provided python file {path} does not contain a class Main that describes a JSON schema"
+ return Main.schema_json()
+
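+# Example of a schema file that can be passed via --json-schema (hypothetical
+# field names; the file must define a pydantic `Main` class as checked above):
+#
+#   from pydantic import BaseModel
+#
+#   class Main(BaseModel):
+#       caption: str
+#       has_speech: bool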
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-base", "-mb", type=str, required=True)
+ parser.add_argument("--model-path", "-mp", type=str, required=True)
+ parser.add_argument("--conv-mode", "-c", type=str, default="auto")
+ parser.add_argument("--text", type=str)
+ parser.add_argument("--media", type=str, nargs="+")
+ parser.add_argument("--json-mode", action="store_true")
+ parser.add_argument("--peft-mode", action="store_true")
+ parser.add_argument("--json-schema", type=str, default=None)
+ args = parser.parse_args()
+
+ # Convert json mode to response format
+ if not args.json_mode:
+ response_format = None
+ elif args.json_schema is None:
+ response_format = ResponseFormat(type="json_object")
+ else:
+ schema_str = get_schema_from_python_path(args.json_schema)
+ print(schema_str)
+ response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str))
+
+ # Load model
+ model = llava.load(args.model_base)
+ if args.peft_mode:
+ model = PeftModel.from_pretrained(
+ model,
+ args.model_path,
+ device_map="auto",
+ torch_dtype=torch.float16,
+ )
+ # Set conversation mode
+ clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
+
+ # Prepare multi-modal prompt
+ prompt = []
+ if args.media is not None:
+ for media in args.media or []:
+ if any(media.endswith(ext) for ext in [".wav",".mp3", ".flac"]):
+ media = Sound(media)
+ else:
+ raise ValueError(f"Unsupported media type: {media}")
+ prompt.append(media)
+ if args.text is not None:
+ prompt.append(args.text)
+
+ # Generate response
+ response = model.generate_content(prompt, response_format=response_format)
+ print(colored(response, "cyan", attrs=["bold"]))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llava/constants.py b/llava/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4e557b88bf7437b49bb88c754d71218ed79dfae
--- /dev/null
+++ b/llava/constants.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+DEFAULT_SOUND_TOKEN = "<sound>"
+DEFAULT_SPEECH_TOKEN = "<speech>"
+SENTINEL_TOKEN = "<vila/sentinel>"
+
+MEDIA_TOKENS = {
+ "speech": "",
+ "sound": "",
+}
+
+
+"""
+151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151648: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151649: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151650: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151651: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151652: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+
+"""
+NUM_EXTRA_TOKENS = 10
diff --git a/llava/conversation.py b/llava/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b42e5bf88571231b262040d3601e63006e125e9
--- /dev/null
+++ b/llava/conversation.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+import dataclasses
+from enum import Enum, auto
+from typing import List
+
+from llava.utils.logging import logger
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+
+ AUTO = auto()
+ TWO = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_3 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ sep_style: SeparatorStyle = SeparatorStyle.AUTO
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+ init_msg = init_msg[0].replace("<image>", "").strip()
+ messages[0] = (init_role, "<image>\n" + init_msg)
+
+ if self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
+ ret = self.system + self.sep
+ for rid, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message = message[0]
+ sep = self.sep if rid < len(messages) - 1 else self.sep2
+ ret += role + message + sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version,
+ )
+
+
+conv_auto = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(),
+ sep_style=SeparatorStyle.AUTO,
+ sep="\n",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="</s>",
+)
+
+conv_llava_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(),
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+hermes_2 = Conversation(
+ system="<|im_start|>system\nAnswer the questions.",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+ messages=(),
+ version="hermes-2",
+)
+
+# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
+llama_3_chat = Conversation(
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language.",
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
+ version="llama_v3",
+ messages=(),
+ sep_style=SeparatorStyle.LLAMA_3,
+ sep="<|eot_id|>",
+ sep2="<|end_of_text|>",
+)
+
+
+default_conversation = conv_auto
+conv_templates = {
+ "auto": conv_auto,
+ "hermes-2": hermes_2,
+ "llama_3": llama_3_chat,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "plain": conv_llava_plain,
+}
+
+
+CONVERSATION_MODE_MAPPING = {
+ "vila1.5-3b": "vicuna_v1",
+ "vila1.5-8b": "llama_3",
+ "vila1.5-13b": "vicuna_v1",
+ "vila1.5-40b": "hermes-2",
+ "llama-3": "llama_3",
+ "llama3": "llama_3",
+}
+
+
+def auto_set_conversation_mode(model_name_or_path: str) -> str:
+ global default_conversation
+ for k, v in CONVERSATION_MODE_MAPPING.items():
+ if k in model_name_or_path.lower():
+ logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
+ default_conversation = conv_templates[v]
+ return
diff --git a/llava/data/__init__.py b/llava/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22d60dab57c08f90e224f3d95047c5dfbf14de6d
--- /dev/null
+++ b/llava/data/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .builder import *
+from .dataset import *
+from .datasets_mixture import *
diff --git a/llava/data/base.py b/llava/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..925a500cf78667b3ab862a72ab1b564e86c2c746
--- /dev/null
+++ b/llava/data/base.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import random
+from typing import Any, Dict, List
+
+import torch
+from torch.utils.data import Dataset
+from transformers import PreTrainedTokenizer
+
+from llava.mm_utils import dynamic_process_images_and_prompt, dynamic_s2_process_images_and_prompt, process_images
+from llava.train.args import DataArguments
+from llava.utils.logging import logger
+from llava.utils.media import extract_media
+from llava.utils.tokenizer import preprocess_conversation
+
+__all__ = ["BaseDataset"]
+
+def _process_speech(speech: List[Any], data_args: DataArguments) -> torch.Tensor:
+ return torch.tensor(speech)
+
+def _process_sound(sound: List[Any], data_args: DataArguments) -> torch.Tensor:
+ return torch.tensor(sound)
+
+def _process_sound_masks(sound_masks: List[Any], data_args: DataArguments) -> torch.Tensor:
+ return torch.tensor(sound_masks)
+
+
+class BaseDataset(Dataset):
+ def __init__(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ data_args: DataArguments,
+ no_system_prompt: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.no_system_prompt = no_system_prompt
+ self.instances = []
+ self.enable_dynamic_res = False
+ self.enable_dynamic_res_s2 = False
+ # global_batch_size: int,
+ self.global_batch_size = kwargs.get("global_batch_size", 1)
+
+ # by default, dataset cls will resample on failure
+ self.resample_on_failure = kwargs.get("resample_on_failure", True)
+
+ def process(self, instance: Dict[str, Any]) -> List[Dict[str, Any]]:
+ raise NotImplementedError
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ instance = self.instances[index]
+
+ try:
+ # Process instance to conversation
+ conversation = self.process(instance)
+
+ # Extract media from conversation
+ media, media_meta = extract_media(conversation, self.data_args)
+
+ if "speech" in media:
+ processed_speech = _process_speech(media["speech"], self.data_args)
+ if "sound" in media:
+ processed_sound = _process_sound(media["sound"], self.data_args)
+ processed_sound_feature_masks = _process_sound_masks(media_meta["sound_feature_masks"], self.data_args)
+ processed_sound_embed_masks = _process_sound_masks(media_meta["sound_embed_masks"], self.data_args)
+ # Prepare "input_ids" and "labels" for training
+ data = preprocess_conversation(conversation, self.tokenizer, no_system_prompt=self.no_system_prompt)
+
+ if "speech" in media:
+ data["speech"] = processed_speech
+ if "sound" in media:
+ data["sound"] = processed_sound
+ data["sound_feature_masks"] = processed_sound_feature_masks
+ data["sound_embed_masks"] = processed_sound_embed_masks
+
+ except Exception as e:
+ if not self.resample_on_failure:
+ raise e
+ else:
+ logger.exception(f"Error processing instance '{instance}': '{e}'. Resampling.")
+ return self.__getitem__(random.randint(0, len(self.instances) - 1))
+
+ return data
+
+ def __len__(self) -> int:
+ return len(self.instances)
diff --git a/llava/data/builder.py b/llava/data/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..871d380c089e5948a58a668846ecf96faaee4afa
--- /dev/null
+++ b/llava/data/builder.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+import os.path as osp
+from itertools import chain
+from typing import Any, List, Optional
+
+import torch
+import torch.distributed as dist
+from hydra.utils import instantiate
+from torch.utils.data import ConcatDataset, Dataset
+from transformers import PreTrainedTokenizer
+
+from llava.data.datasets_mixture import DATASETS_LEGACY
+from llava.train.args import DataArguments, TrainingArguments
+from llava.utils import io
+from llava.utils.logging import logger
+import time
+import numpy as np
+__all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"]
+
+
+def load_dataset_yaml(name):
+ fname = f"{name}.yaml" if not name.endswith(".yaml") else name
+
+ # yaml under llava/data/registry/datasets
+ repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname)
+ if osp.exists(repo_path):
+ return repo_path
+
+ # yaml given as an absolute or user path
+ abs_path = osp.expanduser(fname)
+ if osp.exists(abs_path):
+ return abs_path
+
+ raise FileNotFoundError(f"Dataset '{name}' was not found at {repo_path} or {abs_path}.")
+
+
+def register_datasets(name: Optional[str] = None):
+ if name is None:
+ name = os.environ.get("VILA_DATASETS", "default")
+ logger.info(f"Registering datasets from environment: '{name}'.")
+ # return io.load(osp.join(osp.dirname(__file__), "registry", "datasets", f"{name}.yaml"))
+ dataset_meta = {}
+ for _name in name.split(","):
+ yamlpath = load_dataset_yaml(_name)
+ logger.info(f"Registering datasets from: '{yamlpath}'.")
+ meta = io.load(yamlpath)
+ dataset_meta.update(meta)
+ return dataset_meta
+
+
+def register_mixtures():
+ return io.load(os.path.join(os.path.dirname(__file__), "registry", "mixtures.yaml"))
+
+
+DATASETS = register_datasets()
+MIXTURES = register_mixtures()
+
+
+def parse_mixture(mixture: str) -> List[str]:
+ names = mixture.split("+") if "+" in mixture else [mixture]
+ while any(name in MIXTURES for name in names):
+ names = list(chain(*[MIXTURES.get(name, [name]) for name in names]))
+ return sorted(names)
+
+
+class SubsetDataset(Dataset):
+ def __init__(self, dataset: Dataset, limit: int) -> None:
+ super().__init__()
+ self.dataset = dataset
+ self.limit = limit
+
+ def __len__(self) -> int:
+ return int(len(self.dataset) * self.limit)
+
+ def __getitem__(self, index: int) -> Any:
+ return self.dataset[index % len(self.dataset)]
+
+class RepeatedDataset(Dataset):
+ def __init__(self, dataset: Dataset, times: int) -> None:
+ super().__init__()
+ self.dataset = dataset
+ self.times = times
+
+ def __len__(self) -> int:
+ return len(self.dataset) * self.times
+
+ def __getitem__(self, index: int) -> Any:
+ return self.dataset[index % len(self.dataset)]
+
+
+def get_world_size():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_world_size()
+ else:
+ return 1
+
+
+def build_dataset(
+ mixture: str,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ tokenizer: PreTrainedTokenizer,
+) -> Dataset:
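+ # A mixture string joins registered dataset names with "+" (see parse_mixture).
+ # A name may carry a "*N" suffix to repeat that dataset N times, or a "#P" suffix
+ # to keep only P percent of it, e.g. "dataset_a*2+dataset_b#50" (hypothetical names).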
+ logger.warning(f"Training VILA with mixture '{mixture}'.")
+ datasets = []
+ dataset_rng = np.random.default_rng(1234)
+ for name in parse_mixture(mixture):
+
+ if "*" in name:
+ name, times = name.split("*")
+ times = int(times)
+ else:
+ times = 1
+ limit_dataset = False
+ if "#" in name:
+ # we limit the max length of this dataset
+ name, max_length_percent = name.split("#")
+ limit_dataset = True
+ if DATASETS is not None and name in DATASETS:
+ if name in DATASETS_LEGACY:
+ logger.warning(f"Dataset '{name}' exists in both new and legacy registries. Using the new one.")
+ dataset = instantiate(DATASETS[name], _partial_=True)(
+ tokenizer=tokenizer,
+ data_args=data_args,
+ global_batch_size=(
+ training_args.per_device_train_batch_size
+ # * torch.distributed.get_world_size()
+ * get_world_size()
+ * training_args.gradient_accumulation_steps
+ ),
+ )
+ elif name in DATASETS_LEGACY:
+ logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.")
+ dataset = build_dataset_legacy(
+ name,
+ data_args=data_args,
+ training_args=training_args,
+ tokenizer=tokenizer,
+ )
+ else:
+ raise ValueError(f"Dataset '{name}' is not found in the registries.")
+
+
+ if limit_dataset:
+ # keep only the requested percentage of this dataset
+ dataset = SubsetDataset(dataset, int(max_length_percent) / 100.)
+
+ if times > 1:
+ dataset = RepeatedDataset(dataset, times)
+ datasets.append(dataset)
+ return ConcatDataset(datasets)
+
+
+def build_dataset_legacy(
+ name: str,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ tokenizer: PreTrainedTokenizer,
+) -> Dataset:
+ from llava.data.dataset import (
+ LazySupervisedDataset,
+ LazyWDSDataset,
+ )
+
+ dataset = DATASETS_LEGACY[name]
+ dataset_type = dataset.dataset_type
+ if dataset_type == "torch":
+ dataset_cls = LazySupervisedDataset
+ elif dataset_type == "wds":
+ dataset_cls = LazyWDSDataset
+ else:
+ raise NotImplementedError(f"{dataset_type} is not supported.")
+
+ data_args.meta_path = getattr(dataset, "meta_path", None)
+ data_args.caption_choice = getattr(dataset, "caption_choice", None)
+ data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None)
+ data_args.start_idx = getattr(dataset, "start_idx", None)
+ data_args.end_idx = getattr(dataset, "end_idx", None)
+
+ return dataset_cls(
+ tokenizer=tokenizer,
+ data_path=dataset.data_path,
+ image_folder=getattr(dataset, "image_path"),
+ data_args=data_args,
+ training_args=training_args,
+ )
diff --git a/llava/data/collate.py b/llava/data/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..4795b77a246acce9294f07b189f0a06725b71f2c
--- /dev/null
+++ b/llava/data/collate.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Sequence
+
+import torch
+from transformers import PreTrainedTokenizer
+
+from llava.constants import IGNORE_INDEX
+from llava.utils.logging import logger
+
+__all__ = ["DataCollator"]
+
+
+@dataclass
+class DataCollator:
+ tokenizer: PreTrainedTokenizer
+
+ def __init__(self, tokenizer: PreTrainedTokenizer):
+ super().__init__()
+ self.tokenizer = tokenizer
+
+ def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
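+ # Collate per-sample dicts into a single batch: pad "input_ids"/"labels", build the
+ # attention mask, and gather media objects and per-sample media metadata into the
+ # dict returned at the end of this method.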
+ # Gather everything from the batch
+ input_ids, labels, media, block_sizes = [], [], {name: [] for name in self.tokenizer.media_tokens}, []
+
+ media_meta = {}
+
+ media_meta["sound_feature_masks"] = []
+ media_meta["sound_embed_masks"] = []
+ media_meta["frame_times"] = []
+ for instance in instances:
+ if isinstance(instance["input_ids"], torch.Tensor):
+ input_ids.append(instance["input_ids"])
+ labels.append(instance["labels"])
+ for name in media:
+ objs = instance.get(name)
+ objs = objs if objs is not None else []
+ media[name].append([obj for obj in objs])
+ if instance.get("sound") is not None:
+ for name_k in media_meta:
+ if "sound" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].append([obj for obj in objs])
+ if instance.get("video") is not None or instance.get("image") is not None:
+ for name_k in media_meta:
+ if "frame" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].append([obj for obj in objs])
+ if "block_sizes" in instance:
+ block_sizes.append(instance["block_sizes"])
+ else:
+ block_sizes.append(
+ [None for _ in range(len(instance.get("image")))] if instance.get("image") is not None else []
+ )
+ else:
+ input_ids.extend(instance["input_ids"])
+ labels.extend(instance["labels"])
+ for name in media:
+ objs = instance.get(name)
+ objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))]
+ media[name].extend(objs)
+ if instance.get("sound") is not None:
+ for name_k in media_meta:
+ if "sound" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].extend(objs)
+ if instance.get("video") is not None or instance.get("image") is not None:
+ for name_k in media_meta:
+ if "frame" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].append([obj for obj in objs])
+ if "block_sizes" in instance:
+ block_sizes.extend(instance["block_sizes"])
+ else:
+ block_sizes.extend(
+ [[None for _ in range(len(objs))] for objs in instance.get("image")]
+ if instance.get("image") is not None
+ else [[] for _ in range(len(instance["input_ids"]))]
+ )
+
+ batch_size = len(input_ids)
+
+
+ # Check if the number of media objects (or the number of block sizes) matches the number of media tokens
+ for name in media:
+ for k in range(batch_size):
+ if name == "image" and not all([_ is None for _ in block_sizes[k]]):
+ actual = len(block_sizes[k])
+ else:
+ actual = len(media[name][k])
+ expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
+ if actual != expected:
+ raise ValueError(
+ f"Number mismatch between {name} objects and {name} tokens. "
+ f"There are {expected} {name} tokens but {actual} {name} objects."
+ )
+
+ # Batchify the inputs
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(
+ labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX,
+ )
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
+ labels = labels[:, : self.tokenizer.model_max_length]
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
+
+ # Truncate media objects if necessary
+ for name in media:
+ objects = []
+ for k in range(batch_size):
+ if name == "image" and not all([_ is None for _ in block_sizes[k]]):
+ actual = len(media[name][k])
+ num_large_scale_blocks = sum([x * y for x, y in block_sizes[k]])
+ num_small_scale_blocks = actual - num_large_scale_blocks
+ num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k])
+ expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
+ expected = (
+ sum([x * y for x, y in block_sizes[k][:expected_full_image]])
+ + num_small_scale_blocks_each_img * expected_full_image
+ )
+ if actual > expected:
+ logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
+ media[name][k] = media[name][k][:expected]
+ objects.extend(media[name][k])
+ block_sizes[k] = block_sizes[k][:expected_full_image]
+ else:
+ actual = len(media[name][k])
+ expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
+ if actual > expected:
+ logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
+ media[name][k] = media[name][k][:expected]
+ objects.extend(media[name][k])
+ if name == "image":
+ block_sizes[k] = block_sizes[k][:expected]
+ media[name] = objects
+
+ for name in media_meta:
+ objects = []
+ for k in range(batch_size):
+ try:
+ objects.extend(media_meta[name][k])
+ except:
+ continue
+ media_meta[name] = objects
+
+ # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...] to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...]
+ block_sizes = sum(block_sizes, [])
+ return {
+ "input_ids": input_ids,
+ "media": media,
+ "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}},
+ "labels": labels,
+ "attention_mask": attention_mask,
+ "media_meta": media_meta,
+ }
diff --git a/llava/data/dataset.py b/llava/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e549c6c66fe7db2ce2173f078cdf2d70f291cfd8
--- /dev/null
+++ b/llava/data/dataset.py
@@ -0,0 +1,1635 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import copy
+import io
+import json
+import os
+import os.path as osp
+import random
+import time
+import warnings
+from dataclasses import dataclass
+from typing import Dict, Sequence
+import math
+import numpy as np
+import PIL
+import torch
+import transformers
+from PIL import Image, ImageFile
+from torch.utils.data import Dataset, default_collate
+from transformers import PreTrainedTokenizer
+from transformers import AutoFeatureExtractor
+import kaldiio
+import llava.data.datasets_mixture as datasets_mixture
+from llava import conversation as conversation_lib
+from llava.constants import DEFAULT_SOUND_TOKEN,DEFAULT_SPEECH_TOKEN, IGNORE_INDEX
+from llava.data.collate import DataCollator
+from llava.mm_utils import (
+ load_audio,
+ get_num_windows,
+ tokenizer_image_token,
+)
+from torchvision import transforms
+from llava.train.args import DataArguments, TrainingArguments
+from llava.train.sequence_parallel import (
+ extract_local_from_list,
+ extract_local_input_ids,
+ extract_local_position_ids,
+ get_pg_manager,
+)
+from llava.utils.tokenizer import preprocess_conversation
+# import torchaudio
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
+import soundfile as sf
+from librosa import resample as librosa_resample
+import whisper
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+PIL.Image.MAX_IMAGE_PIXELS = 1000000000
+
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1., a_max=1.)
+ return (x * 32767.).astype(np.int16)
+
+
+
+def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ concat_values = "".join([sentence["value"] for sentence in source])
+ for sid, sentence in enumerate(source):
+ # In multimodal conversations, ensure each media token is followed by exactly one newline.
+
+ if DEFAULT_SOUND_TOKEN in sentence["value"]:
+ sentence["value"] = sentence["value"].replace(DEFAULT_SOUND_TOKEN, f"{DEFAULT_SOUND_TOKEN}\n")
+ sentence["value"] = sentence["value"].replace(f"{DEFAULT_SOUND_TOKEN}\n\n", f"{DEFAULT_SOUND_TOKEN}\n")
+ if DEFAULT_SPEECH_TOKEN in sentence["value"]:
+ sentence["value"] = sentence["value"].replace(DEFAULT_SPEECH_TOKEN, f"{DEFAULT_SPEECH_TOKEN}\n")
+ sentence["value"] = sentence["value"].replace(f"{DEFAULT_SPEECH_TOKEN}\n\n", f"{DEFAULT_SPEECH_TOKEN}\n")
+ return sources
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ no_system_prompt: bool = False,
+) -> Dict:
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ return default_collate(
+ [
+ preprocess_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
+ for conversation in sources
+ ]
+ )
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is originally implemented by the LLaVA team and modified by
+ Ji Lin and Haotian Tang.
+ """
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ ):
+ super().__init__()
+ try:
+ with open(data_path) as fp:
+ list_data_dict = json.load(fp)
+ except:
+ with open(data_path) as fp:
+ list_data_dict = [json.loads(q) for q in fp]
+
+ # rank0_print("Formatting inputs...Skip in lazy mode")
+ print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+ self.image_folder = image_folder
+ self.wav_processor = AutoFeatureExtractor.from_pretrained('Qwen/Qwen2-Audio-7B')
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if "image" in sample else 0
+ length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
+ return length_list
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ if 'duration' in sample.keys():
+ duration = sample["duration"]
+ else:
+ duration = 10.
+ try:
+ cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) + int(math.ceil(duration * 25))
+ cur_len = cur_len if "sound" in sample else -cur_len
+ length_list.append(cur_len)
+ except:
+ try:
+ cur_len = 0 + int(math.ceil(duration * 25))
+ cur_len = cur_len if "sound" in sample else -cur_len
+ length_list.append(cur_len)
+ except:
+ cur_len = 0 + int(math.ceil(10. * 25))
+ cur_len = cur_len if "sound" in sample else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+ @staticmethod
+ def _load_sound(sound_file, wav_processor, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=3, audio_start = 0.0):
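+ # Split the waveform into up to `max_num_window` windows of `window_length` seconds,
+ # run the feature extractor on each window, and return the stacked mel features
+ # together with per-window feature masks and embedding masks.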
+ if sound_file is None:
+ return None
+ window_length = int(window_length * sample_rate)
+ window_overlap = int(window_overlap * sample_rate)
+ max_num_window = int(max_num_window)
+ duration = max_num_window * (window_length - window_overlap) + window_overlap
+
+ sound_outputs = []
+ audio_feature_masks = []
+ audio_embed_masks = []
+
+ try:
+ sound_filename = str.split(sound_file, '/')[-1]
+ if '.ark' in sound_filename:
+ sound = kaldiio.load_mat(sound_file)
+ audio_data = sound[1]
+ audio_data=audio_data.astype(np.float16)
+ else:
+ audio_data = load_audio(sound_file, sample_rate, duration, audio_start) # already cuts to max duration
+ T = len(audio_data)
+ audio_data = audio_data.reshape(1, -1)
+ num_windows, full_length = get_num_windows(T, sample_rate, max_num_window)
+
+ audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
+
+ for i in range(num_windows):
+ audio_embed_mask = torch.zeros(750)
+ start = i * (window_length - window_overlap)
+ audio_data_tensor_this = audio_data_tensor[:, start:start+window_length]
+ orig_length = audio_data_tensor_this.shape[1]
+ audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt")
+ sound_outputs.append(audio_data_tensor_this["input_features"])
+ # calculate the mask for the input melspec to Whisper
+ melspec_frames_this_window = int(math.ceil(orig_length / 160))
+ feature_attention_mask = torch.zeros(3000, dtype=torch.int32)
+ feature_attention_mask[:melspec_frames_this_window] = 1
+ audio_feature_masks.append(feature_attention_mask.unsqueeze(0))
+ # calculate the mask for the output embedding for use in AF2
+ conv_lengths = (melspec_frames_this_window - 1) // 2 + 1
+ output_embedding_lengths = (conv_lengths - 2) // 2 + 1
+ audio_embed_mask[:output_embedding_lengths] = 1
+ audio_embed_masks.append(audio_embed_mask)
+ except:
+ print('error loading file', sound_file)
+ sound_outputs.append(torch.zeros(1,128,3000))
+ audio_feature_masks.append(torch.zeros(1,3000, dtype=torch.int32))
+ audio_embed_masks.append(torch.zeros(750))
+
+ return torch.stack(sound_outputs, dim=0), torch.stack(audio_feature_masks, dim=0), torch.stack(audio_embed_masks, dim=0)
+
+ @staticmethod
+ def _load_speech(speech_path,sample_rate=16000):
+ if speech_path is None:
+ return None
+
+ speech_outputs = []
+ try:
+ speech = whisper.load_audio(speech_path)
+ speech = whisper.pad_or_trim(speech)
+ mel = whisper.log_mel_spectrogram(speech)
+ speech_outputs.append(mel.unsqueeze(0))
+ except:
+ speech_outputs.append(torch.zeros(1,80,3000))
+ return torch.stack(speech_outputs, dim=0)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+
+ import re
+ if "sound" in self.list_data_dict[i]:
+ # chat data loading
+ if isinstance(self.list_data_dict[i]["sound"],list):
+ sound_files = self.list_data_dict[i]["sound"]
+ conversations_raw = self.list_data_dict[i]["conversations"]
+
+ # Step 1: Extract tags in order of appearance
+ sound_tag_pattern = re.compile(r"")
+ ordered_sound_tags = []
+
+ for turn in conversations_raw:
+ tags = sound_tag_pattern.findall(turn["value"])
+ ordered_sound_tags.extend([f"" for tag in tags])
+
+ # Step 2: Load sound tensors in the order of tags
+ sound_tensor = []
+ audio_feature_masks = []
+ audio_embed_masks = []
+ sound_token_map = {}
+
+ for tag in ordered_sound_tags:
+ idx = int(tag.split('-')[1][:-1])
+ if tag not in sound_token_map:
+ this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_files[idx], self.wav_processor, max_num_window=self.data_args.audio_frames)
+ this_sound_tensor = this_sound_tensor.squeeze(1) # (windows x 750 x 2048)
+ sound_token_map[tag] = ("\n" * this_sound_tensor.shape[0]).rstrip()
+ sound_tensor.append(this_sound_tensor)
+ audio_feature_masks.append(af_mask)
+ audio_embed_masks.append(ae_mask)
+ else:
+ # If already loaded, still append to match sequence
+ this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_files[idx], self.wav_processor, max_num_window=self.data_args.audio_frames)
+ this_sound_tensor = this_sound_tensor.squeeze(1)
+ sound_tensor.append(this_sound_tensor)
+ audio_feature_masks.append(af_mask)
+ audio_embed_masks.append(ae_mask)
+
+
+ # Process conversations and inject sound markers
+ conversation = []
+ for turn in conversations_raw:
+ role = turn["from"]
+ value = turn["value"]
+
+ # Replace each <sound-N> tag with its corresponding repeated <sound> tokens
+ for tag, sound_token in sound_token_map.items():
+ value = value.replace(tag, sound_token)
+
+ conversation.append({
+ "from": role,
+ "value": value.rstrip()
+ })
+
+ sources = [conversation]
+ sound_tensor = torch.cat(sound_tensor, dim=0)
+ audio_feature_masks = torch.cat(audio_feature_masks, dim=0)
+ audio_embed_masks = torch.cat(audio_embed_masks, dim=0)
+ else:
+ sound_file = self.list_data_dict[i]["sound"]
+ question = str(self.list_data_dict[i]["conversations"][0]["value"].rstrip())
+ answer = str(self.list_data_dict[i]["conversations"][1]["value"]).rstrip()
+ question = question.replace("\n", "").replace("\n", "").replace("", "")
+ question = question.replace("\n", "").replace("\n", "").replace("", "")
+ question = question.replace("\n", "").replace("\n", "").replace("", "")
+ question = question.replace("\n", "").replace("\n", "").replace("", "")
+ sound_tensor, audio_feature_masks, audio_embed_masks = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames)
+ sound_tensor=sound_tensor.squeeze(1) # squeeze the irrelevant dimension which was caused due to processor getting 1 batch for processing --> (windows x 750 x 2048)
+ question = "\n" * sound_tensor.shape[0] + question
+ conversation = [
+ {"from": "human", "value": question},
+ {"from": "gpt", "value": answer},
+ ]
+
+ sources = [conversation]
+ data_dict = preprocess(
+ sources,
+ self.tokenizer,
+ has_image=(
+ "sound" in self.list_data_dict[i]
+ ),
+ )
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+ if "sound" in self.list_data_dict[i]:
+ data_dict["sound"] = sound_tensor
+ data_dict["sound_feature_masks"] = audio_feature_masks
+ data_dict["sound_embed_masks"] = audio_embed_masks
+ if "speech" in self.list_data_dict[i]:
+ data_dict["speech"] = speech_tensor
+
+ return data_dict
+
+
+class LazyMMC4Dataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ji Lin and Haotian Tang."""
+
+ num_image_tokens = 576  # per-image token count used by modality_lengths (same value as LazyCoyoDataset)
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ image_following_text_only=False,
+ text_only=False,
+ ):
+ super().__init__()
+
+ import pickle
+
+ n_samples = []
+ # each shard contributes a data (.pkl) file and a stats (.count) file
+ n_shards = len(os.listdir(data_path)) // 2
+ # n_shards = 100
+ count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards]
+ n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list]
+
+ print("total MMC4 samples", sum(n_samples)) # 10,881,869
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
+ shared_size = n_shards // world_size
+
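+ # keep min(gpu_samples) samples per rank so that every rank reports the same dataset length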
+ gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)]
+ self.n_samples = min(gpu_samples) * world_size # total size
+ self.idx_offset = rank * min(gpu_samples)
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
+ print(f" * loading data from shard {shard_start}-{shard_end}")
+
+ shard_names = [d.replace(".count", ".pkl") for d in count_info_list]
+ shard_names = shard_names[shard_start:shard_end]
+
+ full_data_list = []
+ # now load data
+ for shard_name in shard_names:
+ # load shard
+ with open(os.path.join(data_path, shard_name), "rb") as f:
+ data_list = pickle.load(f)
+
+ full_data_list.extend(data_list)
+
+ print(f"* loaded totally {len(full_data_list)} samples")
+
+ self.data_list = full_data_list
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.image_folder = image_folder
+
+ self.image_following_text_only = image_following_text_only
+ self.text_only = text_only
+
+ def __len__(self):
+ # return len(self.data_list)
+ return self.n_samples
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for info in self.data_list:
+ num_images = min(6, len(info["image_info"]))
+ sentences = [info["text_list"][x["matched_text_index"]] for x in info["image_info"][:num_images]]
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = num_images * self.num_image_tokens // 2 + sum([len(x) for x in sentences])
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ info = self.data_list[i - self.idx_offset]
+
+ sentences = info["text_list"]
+ # kentang-mit@: remove existing <image> tokens in the sentences
+ for ix in range(len(sentences)):
+ # if this is an html tag, we still preserve its semantic meaning
+ sentences[ix] = sentences[ix].replace("<image>", "<IMAGE>")
+ sim_matrix = info["similarity_matrix"] # we do not use this...
+
+ # convert images from base64 to PIL and filter based on image-text similarity
+ images, sentence_ixs = [], []
+ if not self.text_only:
+ for sample_image, sim_vec in zip(info["image_info"], sim_matrix):
+ image_base64 = sample_image["image_base64"]
+ rawbytes = base64.b64decode(image_base64)
+
+ sim_ix = sample_image["matched_text_index"]
+ # sim_ix = np.argmax(sim_vec)
+ # sim_score = sim_vec[sim_ix]
+
+ # filter to images >= 5KB
+ # if len(rawbytes) // 1000 <= 5:
+ # continue
+ # if sim_score < 0.24:
+ # continue
+ image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
+
+ images.append(image)
+ sentence_ixs.append(sim_ix)
+
+ # constrain max num 6 images
+ max_num_images = 6
+ if len(images) > max_num_images:
+ images = images[:max_num_images]
+ sentence_ixs = sentence_ixs[:max_num_images]
+
+ # reorder images according to text insertion
+ images = [images[iii] for iii in np.argsort(sentence_ixs)]
+
+ # preprocess and tokenize text
+ for ix in sentence_ixs:
+ sentences[ix] = f"\n{sentences[ix]}"
+
+ if self.image_following_text_only:
+ # use pad tokens to divide sentence pieces
+ text = self.tokenizer.pad_token.join(sentences)
+ else:
+ text = " ".join(sentences)
+ # whitespace cleanup
+ text = text.replace(" ", "").replace(" ", "")
+ text = f"{text}{self.tokenizer.eos_token}" # add eos token
+
+ if len(images) > 0:
+ if self.data_args.image_aspect_ratio == "dynamic_s2":
+ images, block_sizes = dynamic_s2_process_images_and_prompt(
+ images, text, self.data_args, self.image_folder
+ )
+ elif self.data_args.image_aspect_ratio == "dynamic":
+ images, text = dynamic_process_images_and_prompt(
+ images, text, self.data_args, self.image_folder, max_tiles=6
+ )
+ else:
+ images = torch.stack([process_image(image, self.data_args, self.image_folder) for image in images])
+
+ # the same size for all images, so we concat
+ # cur_token_len = (
+ # images[0].shape[-2] // self.multimodal_cfg["patch_size"]
+ # ) * (images[0].shape[-1] // self.multimodal_cfg["patch_size"])
+ # cur_token_len += self.multimodal_cfg["n_extra_patch"]
+ else:
+ images = None
+ # cur_token_len = 0
+
+ input_ids = tokenizer_image_token(
+ text,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+
+ image_token_id = self.tokenizer.media_token_ids["image"]
+
+ # now check the case where the last token is image patch token
+ if input_ids[-1] == image_token_id: # need to remove one last image
+ last_non_im_patch_indices = torch.where(input_ids != image_token_id)[0][-1] + 1
+ input_ids = input_ids[:last_non_im_patch_indices]
+
+ n_im_patch = (input_ids == image_token_id).sum().item()
+
+ if self.data_args.image_aspect_ratio != "dynamic_s2":
+ images = images[:n_im_patch]
+ assert len(images) == n_im_patch, print(text, input_ids)
+ assert len(input_ids.shape) == 1, "Unexpected shape of 'input_ids' from MMC4."
+ input_ids = (
+ torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids])
+ if self.tokenizer.bos_token_id is not None and input_ids[0] != self.tokenizer.bos_token_id
+ else input_ids
+ )
+ targets = input_ids.clone()
+
+ if self.image_following_text_only: # keep only text after leading image token
+ # remove loss for any token before the first <image> token
+ label_idx = 0
+ while label_idx < targets.shape[-1] and targets[label_idx] != image_token_id:
+ targets[label_idx] = IGNORE_INDEX
+ label_idx += 1
+
+ pad_token = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0]
+
+ pad_token_idxs = torch.where(targets == pad_token)[0]
+ for pad_token_idx in pad_token_idxs:
+ token_idx = pad_token_idx + 1
+ while token_idx < targets.shape[-1] and targets[token_idx] != image_token_id:
+ targets[token_idx] = IGNORE_INDEX
+ token_idx += 1
+ # do not train on padding tokens
+ targets[targets == pad_token] = IGNORE_INDEX
+
+ # mask image tokens is unnecessary for llava-1.5
+ # targets[targets == IMAGE_TOKEN_INDEX] = IGNORE_INDEX
+ # print(input_ids.shape)
+
+ data_dict = dict(input_ids=input_ids, labels=targets, image=images)
+ if self.data_args.image_aspect_ratio == "dynamic_s2":
+ data_dict["block_sizes"] = block_sizes
+
+ return data_dict
+
+
+class LazyCoyoDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ji Lin and Haotian Tang."""
+
+ num_image_tokens = 576
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ # kentang-mit@: balance the total number of tokens for Coyo and MMC4.
+ n_samples_per_idx=4,
+ ):
+ super().__init__()
+
+ import pickle
+
+ n_samples = []
+ # each shard contributes a data (.pkl) file and a stats (.count) file
+ n_shards = len(os.listdir(data_path)) // 2
+ # n_shards = 100
+ count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards]
+ n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list]
+
+ print("total COYO samples", sum(n_samples))
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
+ shared_size = n_shards // world_size
+
+ gpu_samples = [
+ sum(n_samples[i * shared_size : (i + 1) * shared_size]) // n_samples_per_idx for i in range(world_size)
+ ]
+ self.n_samples = min(gpu_samples) * world_size # total size
+ self.idx_offset = rank * min(gpu_samples)
+
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
+ print(f" * loading data from shard {shard_start}-{shard_end}")
+
+ shard_names = [d.replace(".count", ".pkl") for d in count_info_list]
+ shard_names = shard_names[shard_start:shard_end]
+
+ full_data_list = []
+ # now load data
+ for shard_name in shard_names:
+ # load shard
+ with open(os.path.join(data_path, shard_name), "rb") as f:
+ shard_data = pickle.load(f)
+ random.seed(42)
+ if "mmc4" in data_path:
+ random.shuffle(shard_data) # shuffle for MMC4cap only
+ full_data_list.extend(shard_data)
+
+ print(f"* loaded totally {len(full_data_list)} samples")
+
+ # now pack the samples into groups
+ n_groups = len(full_data_list) // n_samples_per_idx
+ full_data_list = [
+ full_data_list[i : i + n_samples_per_idx] for i in range(0, len(full_data_list), n_samples_per_idx)
+ ]
+ if len(full_data_list[-1]) < n_samples_per_idx:
+ full_data_list = full_data_list[:-1]
+ assert len(full_data_list) == n_groups
+ print(f"split into {n_groups} groups")
+
+ self.data_list = full_data_list
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.image_folder = image_folder
+
+ def __len__(self):
+ # return len(self.data_list)
+ return self.n_samples
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for samples in self.data_list:
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ CONCAT_SAMPLES = False
+ info_list = self.data_list[i - self.idx_offset]
+
+ text_list = []
+ image_list = []
+
+ for sample in info_list:
+ caption_key = (
+ "text" if "text" in sample else "caption"
+ )
+ # kentang-mit@: remove existing <image> tokens in the sentences;
+ # if this is an html tag, we still preserve its semantic meaning
+ sample[caption_key] = sample[caption_key].replace("<image>", "<IMAGE>")
+ text_list.append(DEFAULT_IMAGE_TOKEN + "\n" + sample[caption_key] + self.tokenizer.eos_token)
+ if "image" in sample:
+ image_base64 = sample["image"]
+ rawbytes = base64.b64decode(image_base64)
+ else:
+ rawbytes = sample["rawbytes"]
+ image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
+ image_list.append(image)
+
+ image_list = torch.stack([process_image(image, self.data_args, self.image_folder) for image in image_list])
+
+ if CONCAT_SAMPLES:
+ # into capcap...
+ text_list = "".join(text_list)
+
+ input_ids = self.tokenizer(
+ text_list,
+ return_tensors="pt",
+ padding="longest",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids # 4, seq_len
+
+ input_ids = input_ids[0]
+
+ else:
+ input_ids = [
+ tokenizer_image_token(
+ prompt,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+ for prompt in text_list
+ ]
+ # print([x.shape[0] for x in input_ids], [len(x.split()) for x in text_list], [len(re.findall(r"<image[^>]*>", x)) for x in text_list])
+
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
+ # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ # )
+
+ targets = copy.deepcopy(input_ids)
+ for i in range(len(targets)):
+ targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets, image=image_list)
+
+
+class LazyWDSDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ji Lin and Ligeng Zhu."""
+
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ image_folder: str,
+ training_args: TrainingArguments,
+ ):
+ super().__init__()
+ n_samples = []
+ n_shards = len(os.listdir(data_path)) // 3
+ for shard in range(n_shards):
+ with open(os.path.join(data_path, f"{shard:05d}_stats.json")) as f:
+ info = json.load(f)
+ n_samples.append(info["successes"])
+
+ # print(f"[DEBUG] {data_path} total samples", sum(n_samples)) # 10,881,869
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
+ shared_size = n_shards // world_size
+ print("rank", rank, "world_size", world_size, "shared_size", shared_size)
+ gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)]
+ self.n_samples = min(gpu_samples) * world_size # total size
+ self.idx_offset = rank * min(gpu_samples)
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
+ print(f" * loading data from shard {shard_start}-{shard_end}")
+
+ tar_list = [f"{shard_idx:05d}.tar" for shard_idx in range(shard_start, shard_end)]
+
+ self.data_list = []
+ t1 = time.time()
+ for tar in tar_list:
+ tmp_path = f"/tmp/ccs{tar}"
+ tar_path = os.path.join(data_path, tar)
+
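+ # with sequence parallelism, only one rank per group extracts the tar; the others wait at the barriers and reuse the extracted files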
+ if PROCESS_GROUP_MANAGER is not None:
+ dist.barrier()
+ if PROCESS_GROUP_MANAGER.sp_rank == 0:
+ os.makedirs(tmp_path, exist_ok=True)
+ os.system(f"tar -xkf {tar_path} -C {tmp_path}")
+ dist.barrier()
+ else:
+ os.makedirs(tmp_path, exist_ok=True)
+ os.system(f"tar -xkf {tar_path} -C {tmp_path}")
+
+ txt_list = [f for f in os.listdir(tmp_path) if f.endswith(".txt")]
+
+ for txt in txt_list:
+ caption = open(os.path.join(tmp_path, txt)).read().strip()
+ image_path = os.path.join(tmp_path, txt.split(".")[0] + ".jpg")
+ self.data_list.append({"caption": caption, "image": image_path})
+ t2 = time.time()
+ print(f"Loading done. Total time: {t2 - t1:.2f} seconds")
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.image_folder = image_folder
+
+ def __len__(self):
+ return self.n_samples
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+
+ # print("i", i, "idx_offset", self.idx_offset, "len", len(self.data_list))
+ info = self.data_list[i - self.idx_offset]
+ caption, image_path = info["caption"], info["image"]
+
+ rand_prompt = "\n"
+ sources = [
+ {
+ "image": image_path,
+ "conversations": [
+ {"from": "human", "value": rand_prompt},
+ {"from": "gpt", "value": caption},
+ ],
+ }
+ ]
+
+ # one example of sources
+ # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}]
+ if "image" in sources[0]:
+ image = process_image(sources[0]["image"], self.data_args, self.image_folder)
+ image = torch.unsqueeze(image, dim=0)
+ # now random pick some context samples for training
+ if hasattr(self.data_args, "num_shots"):
+ if self.data_args.num_shots > 0:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True)
+
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if image is not None:
+ data_dict["image"] = image
+ else:
+ raise NotImplementedError
+
+ return data_dict
+
+
+class LazyCCSWebDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ligeng Zhu."""
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ ):
+ super().__init__()
+ t1 = time.time()
+
+ from llava.data.simple_vila_webdataset import VILAWebDataset
+
+ print("[DEBUG] ", osp.abspath(data_path))
+ self.dataset = VILAWebDataset(data_path=osp.abspath(data_path))
+
+ t2 = time.time()
+ print(f"Loading done. Total time: {t2 - t1:.2f} seconds")
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ # info = self.data_list[i - self.idx_offset]
+ # caption, image_path = info["caption"], info["image"]
+ info = self.dataset[i]
+ if ".jpg" in info:
+ caption, image_path = info[".txt"], info[".jpg"]
+ elif ".png" in info:
+ caption, image_path = info[".txt"], info[".png"]
+ elif ".webp" in info:
+ caption, image_path = info[".txt"], info[".webp"]
+ elif ".bmp" in info:
+ caption, image_path = info[".txt"], info[".bmp"]
+ elif ".tiff" in info:
+ caption, image_path = info[".txt"], info[".tiff"]
+ else:
+ print(info.keys())
+ print(info)
+ raise KeyError
+
+ caption = caption.replace("", "")
+ if isinstance(image_path, io.BytesIO):
+ image_path = Image.open(image_path).convert("RGB")
+
+ if not isinstance(image_path, PIL.Image.Image):
+ print(image_path)
+ print(info.keys())
+ print(type(image_path))
+ raise NotImplementedError
+
+ rand_prompt = "\n"
+ sources = [
+ {
+ "image": image_path,
+ "conversations": [
+ {"from": "human", "value": rand_prompt},
+ {"from": "gpt", "value": caption},
+ ],
+ }
+ ]
+
+ # one example of sources
+ # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}]
+ if "image" in sources[0]:
+ image = process_image(sources[0]["image"], self.data_args, image_folder=None)
+ image = torch.unsqueeze(image, dim=0)
+ # now random pick some context samples for training
+ if hasattr(self.data_args, "num_shots"):
+ if self.data_args.num_shots > 0:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True)
+
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if image is not None:
+ data_dict["image"] = image
+ else:
+ raise NotImplementedError
+
+ return data_dict
+
+
+from functools import lru_cache
+
+
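+# cache parsed caption-override JSON files so that consecutive samples from the same shard do not re-read them from disk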
+@lru_cache(maxsize=16)
+def lru_json_load(fpath):
+ with open(fpath) as fp:
+ return json.load(fp)
+
+
+class LazyCoyoWebDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ligeng Zhu."""
+
+ num_image_tokens = 576
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ # kentang-mit@: balance the total number of tokens for Coyo and MMC4.
+ n_samples_per_idx=4,
+ ):
+ super().__init__()
+
+ from llava.data.simple_vila_webdataset import VILAWebDataset
+
+ print("[DEBUG] ", osp.abspath(data_path))
+ self.dataset = VILAWebDataset(data_path=osp.abspath(data_path), meta_path=data_args.meta_path)
+
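+ # start_idx / end_idx are given as fractions of the dataset size (0-1), not absolute indices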
+ if data_args.start_idx >= 0 and data_args.end_idx >= 0:
+ # Ligeng: support slicing for ablate different subsets.
+ total = len(self.dataset)
+ start_idx = int(total * data_args.start_idx)
+ end_idx = int(total * data_args.end_idx)
+ print(f"loading subset from {start_idx} to {end_idx}, total {total}")
+ self.dataset = torch.utils.data.Subset(self.dataset, range(start_idx, end_idx))
+
+ # Caption choice:
+ # if None: use the original caption
+ # if a folder path: use the specified captions to override the original ones (caption_choice)
+ # if a folder path: use the specified captions concatenated with the original ones (caption_choice_2)
+ self.caption_choice = None
+ self.caption_choice_2 = None
+ self.data_path = data_path
+
+ if data_args.caption_choice is not None:
+ self.caption_choice = data_args.caption_choice
+ print("[recap] Override coyo caption using ", self.caption_choice)
+
+ if data_args.caption_choice_2 is not None:
+ self.caption_choice_2 = data_args.caption_choice_2
+ print("[recapv2] Override coyo caption using ", self.caption_choice_2)
+
+ print("total samples", len(self.dataset))
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = (
+ training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2
+ ) # int(os.environ["RANK"])
+ world_size = (
+ training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32
+ ) # int(os.environ["WORLD_SIZE"])
+ print(
+ "rank",
+ rank,
+ "world_size",
+ world_size,
+ )
+
+ self.n_samples_per_idx = n_samples_per_idx
+ # self.n_samples = len(self.dataset) // n_samples_per_idx
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.dataset) // self.n_samples_per_idx
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for samples in self.data_list:
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ CONCAT_SAMPLES = False
+ # info_list = self.dataset[i - self.idx_offset]
+
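+ # each __getitem__ packs n_samples_per_idx consecutive webdataset entries into a single training sample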
+ begin_idx, end_idx = (
+ i * self.n_samples_per_idx,
+ (i + 1) * self.n_samples_per_idx,
+ )
+ end_idx = min(end_idx, len(self.dataset))
+
+ text_list = []
+ image_list = []
+
+ for idx in range(begin_idx, end_idx):
+ info = self.dataset[idx]
+ if ".jpg" in info:
+ caption, image_path = info[".txt"], info[".jpg"]
+ elif ".png" in info:
+ caption, image_path = info[".txt"], info[".png"]
+ elif ".webp" in info:
+ caption, image_path = info[".txt"], info[".webp"]
+ elif ".bmp" in info:
+ caption, image_path = info[".txt"], info[".bmp"]
+ elif ".tiff" in info:
+ caption, image_path = info[".txt"], info[".tiff"]
+ else:
+ print(info.keys())
+ print(info)
+ raise KeyError
+
+ if self.caption_choice is not None:
+ # load new captions
+ shard = info["__shard__"]
+ url = info[".json"]["url"]
+ tar_name = osp.relpath(osp.realpath(shard), osp.realpath(self.data_path))
+ # tar_name = osp.dirname(shard)
+ shard_json_path = osp.join(self.caption_choice, tar_name + ".json")
+ try:
+ shard_json = lru_json_load(shard_json_path)
+ try:
+ caption = shard_json[url]["output"]
+ except KeyError:
+ print(f"{url} not in caption. fallback to original caption temporarially")
+ except Exception:
+ print(f"shard_json_path {shard_json_path} not found. falling back to the original caption temporarily")
+ caption = caption.replace("", "")
+ text_list.append(DEFAULT_IMAGE_TOKEN + caption + self.tokenizer.eos_token)
+
+ if isinstance(image_path, io.BytesIO):
+ image_path = Image.open(image_path).convert("RGB")
+
+ if not isinstance(image_path, PIL.Image.Image):
+ print(image_path)
+ print(info.keys())
+ print(type(image_path))
+ raise NotImplementedError
+
+ image_list.append(image_path)
+
+ # image_list = torch.stack([process_image(image, self.data_args, image_folder=None) for image in image_list])
+ # NOTE(fix by ligeng)
+ # now image_list should return a list of image tensor where each has a dimension of (1, c, h, w)
+ image_list = [process_image(image, self.data_args, image_folder=None).unsqueeze(0) for image in image_list]
+
+ if CONCAT_SAMPLES:
+ # into capcap...
+ text_list = "".join(text_list)
+
+ input_ids = self.tokenizer(
+ text_list,
+ return_tensors="pt",
+ padding="longest",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids # 4, seq_len
+
+ input_ids = input_ids[0]
+ else:
+ input_ids = [
+ tokenizer_image_token(
+ prompt,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+ for prompt in text_list
+ ]
+ input_ids = [
+ (
+ torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids_i])
+ if input_ids_i[0] != self.tokenizer.bos_token_id
+ else input_ids_i
+ )
+ for input_ids_i in input_ids
+ ]
+
+ targets = copy.deepcopy(input_ids)
+ for i in range(len(targets)):
+ targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets, image=image_list)
+
+
+class LazyVideoWebDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ # cache_path: str,
+ # n_samples_per_idx=4,
+ ):
+ super().__init__()
+
+ # from llava.data.simple_video_dataset import SimpleVideoDataset
+
+ from llava.data.simple_vila_webdataset import VILAWebDataset
+
+ print("[DEBUG] ", osp.abspath(data_path))
+ self.dataset = VILAWebDataset(
+ data_path=osp.abspath(data_path),
+ meta_path=f"{osp.abspath(data_path)}/wids-meta.json",
+ # cache_dir=cache_path,
+ )
+
+ # None: use the original caption
+ # Folder path: use the specified captions to override the original ones
+ self.caption_choice = None
+ self.data_path = data_path
+
+ if data_args.caption_choice is not None:
+ self.caption_choice = data_args.caption_choice
+ print("[recap] Override LazyVideo caption using ", self.caption_choice)
+
+ print("total samples", len(self.dataset))
+ # InternVid: TODO
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = (
+ training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2
+ ) # int(os.environ["RANK"])
+ world_size = (
+ training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32
+ ) # int(os.environ["WORLD_SIZE"])
+ print(
+ "rank",
+ rank,
+ "world_size",
+ world_size,
+ )
+ self.rank = rank
+ # rank = int(os.environ["RANK"]) if "RANK" in os.environ else 2
+ # world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 32
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ self.missing_uids = set()
+
+ def __len__(self):
+ return len(self.dataset)
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for samples in self.data_list:
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ ADD_TEXT_PROMPT = False
+ num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8
+ loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0
+
+ info = self.dataset[i]
+
+ caption = ""
+ # print(info)
+ if ".mp4" in info:
+ caption, video_path = info[".txt"], info[".mp4"]
+ else:
+ video_path = None
+ caption = "Empty video."
+
+ images, frames_loaded, _ = LazySupervisedDataset._load_video(
+ video_path, num_video_frames, loader_fps, self.data_args
+ )
+
+ if frames_loaded == 0:
+ caption = "Empty video."
+
+ if self.caption_choice is not None:
+ shard = info["__shard__"]
+ uuid = osp.join(info["__shard__"], info["__key__"])
+ url = info["__key__"]
+ tar_name = osp.basename(info["__shard__"])
+
+ try:
+ shard_json_path = osp.join(self.caption_choice, tar_name.replace(".tar", ".json"))
+ shard_json = lru_json_load(shard_json_path)
+ caption = shard_json[url]["summary"]["output"]
+ except (KeyError, FileNotFoundError, json.decoder.JSONDecodeError):
+ if uuid not in self.missing_uids:
+ print("override caption not found for ", uuid)
+ self.missing_uids.add(uuid)
+
+ # print(f"[DEBUG {uuid}]", caption)
+
+ frames_loaded_successfully = len(images)
+ if caption is None:
+ caption = ""
+ prompt = "\n" * frames_loaded_successfully + caption
+ image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images])
+
+ input_ids = tokenizer_image_token(
+ prompt,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+ targets = copy.deepcopy(input_ids)
+ data_dict = dict(input_ids=input_ids, labels=targets, image=image_tensor)
+
+ return data_dict
+
+
+class DataCollatorForSupervisedDatasetSeqParallel:
+ """Collate examples for supervised fine-tuning.
+ This class is originally implemented by the LLaVA team and
+ modified by Haotian Tang."""
+
+ def __init__(
+ self,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ sp_degree: int,
+ sp_rank: int,
+ ring_degree: int,
+ ring_type: str,
+ ):
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.training_args = training_args
+ self.sp_degree = sp_degree
+ self.sp_rank = sp_rank
+ self.ring_degree = ring_degree
+ self.ring_type = ring_type
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels, images = [], [], []
+ image_token_id = self.tokenizer.media_token_ids["image"]
+ video_token_id = self.tokenizer.media_token_ids["video"]
+
+ for instance in instances:
+ if not isinstance(instance["input_ids"], list):
+ input_ids.append(instance["input_ids"])
+ else:
+ input_ids += instance["input_ids"]
+ if not isinstance(instance["labels"], list):
+ labels.append(instance["labels"])
+ else:
+ labels += instance["labels"]
+ # Note (kentang-mit@: we do not directly push tensors to
+ # images, but list of tensors.
+ if "video" in instance:
+ instance["image"] = torch.cat(instance["video"])
+ video_id_pos = torch.where(input_ids[-1] == video_token_id)[0][0]
+ replace_ids = torch.Tensor(
+ ([image_token_id] + self.tokenizer.encode("\n")) * instance["image"].shape[0],
+ device=input_ids[-1].device,
+ )
+ input_ids[-1] = torch.cat(
+ [input_ids[-1][:video_id_pos], replace_ids, input_ids[-1][video_id_pos + 1 :]]
+ ).to(input_ids[-1].dtype)
+ labels[-1] = torch.cat(
+ [
+ labels[-1][:video_id_pos],
+ torch.Tensor([IGNORE_INDEX] * instance["image"].shape[0] * 2),
+ labels[-1][video_id_pos + 1 :],
+ ]
+ ).to(labels[-1].dtype)
+ instance.pop("video")
+
+ if "image" in instance:
+ cur_image = instance["image"]
+ assert len(cur_image.shape) == 4
+ # n_images, 3, size, size
+ if cur_image.shape[0] == 0:
+ warnings.warn("loaded one sample without images.")
+ if not isinstance(instance["input_ids"], list):
+ # datasets other than coyo, not packing >1 samples together
+ images.append(cur_image)
+ else:
+ # coyo-like datasets
+ images.extend(cur_image.chunk(cur_image.size(0), 0))
+ else:
+ warnings.warn("loaded one sample without images.")
+ images.append([])
+ # kentang-mit@: we need to make sure these two lists have
+ # the same length. We will use input_ids to filter out images corresponding
+ # to truncated tokens later.
+
+ max_num_images = max([len(_images) for _images in images])
+ for _images, _input_ids in zip(images, input_ids):
+ assert (
+ len(_images) == (_input_ids == image_token_id).sum().item()
+ ), f"Number mismatch between images and placeholder image tokens in 'len(_images) == (_input_ids == image_token_id).sum().item()'.\
+ Expect to have {len(_images)} images but only found {(_input_ids == image_token_id).sum().item()} images in tokens. \
+ Error input_ids: {_input_ids} {self.tokenizer.decode([x if x != -200 else 200 for x in _input_ids])}"
+
+ NUM_TOKENS_PER_IMAGE = self.data_args.num_image_tokens
+ if hasattr(self.data_args.image_processor, "crop_size"):
+ crop_size = self.data_args.image_processor.crop_size
+ else:
+ crop_size = self.data_args.image_processor.size
+
+ # Init the padding sample
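+ # The loop below runs at most once: it builds a dummy (bos, image, eos) sample that is later used to pad packed groups that do not contain enough images for the sequence-parallel split.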
+ seq_id = 0
+ while seq_id < len(input_ids):
+ # the first sample is only used as a reference for dtype/device when building the dummy tensors
+ dummy_image = torch.ones((1, 3, crop_size["height"], crop_size["width"]), device=input_ids[seq_id].device)
+ # dummy input_ids include one bos, one image token, and one eos
+ dummy_input_ids = torch.zeros_like(input_ids[seq_id][:3])
+ dummy_input_ids[0] = self.tokenizer.bos_token_id
+ dummy_input_ids[1] = image_token_id
+ dummy_input_ids[2] = self.tokenizer.eos_token_id
+ dummy_labels = copy.deepcopy(dummy_input_ids)
+ dummy_labels[:2] = IGNORE_INDEX
+ dummy_seqlen = NUM_TOKENS_PER_IMAGE + 2 # TODO: Check the hard coding of 2
+ dummy_position_ids = torch.arange(start=0, end=dummy_seqlen, dtype=torch.int32)
+ break
+
+ # Sort with the real length of the sequence
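+ # each image placeholder later expands to NUM_TOKENS_PER_IMAGE tokens, hence the (NUM_TOKENS_PER_IMAGE - 1) correction in the sort key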
+ combined = sorted(
+ zip(input_ids, labels, images),
+ key=lambda x: len(x[2]) * (NUM_TOKENS_PER_IMAGE - 1) + x[0].size(-1),
+ reverse=True, # Start Packing from the sequence with most images.
+ )
+ sorted_ids, sorted_labels, sorted_images = zip(*combined)
+ sorted_ids, sorted_labels, sorted_images = list(sorted_ids), list(sorted_labels), list(sorted_images)
+ max_seq_length = self.tokenizer.model_max_length # len(sorted_ids[0])
+ max_sample_len = 0
+
+ batches = []
+ label_batches = []
+ position_ids = []
+ batch_images = []
+ seqlens_in_batch = []
+
+ i = 0
+ while i < len(sorted_ids):
+ current_batch = torch.tensor([], dtype=torch.int32)
+ current_label_batch = torch.tensor([], dtype=torch.int32)
+ current_position_ids = torch.tensor([], dtype=torch.int32)
+ current_batch_images = []
+ current_num_images = 0
+ current_len = 0
+ current_num_samples = 0
+
+ # Pack a few samples into one sample
+ while i < len(sorted_ids):
+ num_images = (sorted_ids[i] == image_token_id).sum().item()
+ num_image_tokens_added = num_images * (NUM_TOKENS_PER_IMAGE - 1)
+ num_incoming_tokens = sorted_ids[i].size(-1) + num_image_tokens_added
+
+ # Handle RingAttn_Varlen, which requires each entry of `seqlens_in_batch` to be divisible by the sequence-parallel degree
+ if self.ring_degree > 1:
+ RING_PAD_TOKEN_INDEX = 2
+ if self.ring_type == "ring_varlen":
+ if num_incoming_tokens % self.sp_degree != 0:
+ pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree
+ num_incoming_tokens += pad_len
+ # pad `input_ids`
+ pad_tensor = torch.full(
+ (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
+ )
+ sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
+
+ # pad `label`
+ pad_label_tensor = torch.full(
+ (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
+ )
+ sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
+ elif self.ring_type == "zigzag_ring_varlen":
+ self.zigzag_sp_degree = self.sp_degree * 2
+ if num_incoming_tokens % self.zigzag_sp_degree != 0:
+ pad_len = self.zigzag_sp_degree - num_incoming_tokens % self.zigzag_sp_degree
+ num_incoming_tokens += pad_len
+ # pad `input_ids`
+ pad_tensor = torch.full(
+ (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
+ )
+ sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
+
+ # pad `label`
+ pad_label_tensor = torch.full(
+ (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
+ )
+ sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
+ else:
+ raise ValueError(f"Invalid ring_type: {self.ring_type}")
+
+ if num_incoming_tokens > max_seq_length:
+ print(
+ f"Warning: Skipping one packed sample with {num_incoming_tokens} tokens,\
+ please consider increase max seq len {max_seq_length}."
+ )
+ i += 1
+ continue
+
+ if (
+ (current_num_images == 0)
+ or (current_num_images < self.sp_degree)
+ or (
+ (current_num_images + num_images <= max_num_images)
+ and (current_len + num_incoming_tokens <= max_sample_len)
+ )
+ ) and (current_len + num_incoming_tokens <= max_seq_length):
+ current_num_images += num_images
+ current_len += num_incoming_tokens
+ current_num_samples += 1
+ current_position_ids = torch.cat(
+ (current_position_ids, torch.arange(start=0, end=num_incoming_tokens)), dim=0
+ )
+ current_batch = torch.cat((current_batch, sorted_ids[i]), dim=0)
+ sorted_labels[i][0] = IGNORE_INDEX
+ current_label_batch = torch.cat((current_label_batch, sorted_labels[i]), dim=0)
+ seqlens_in_batch.append(num_incoming_tokens)
+ current_batch_images.extend(sorted_images[i])
+ i += 1
+ assert current_num_images == len(current_batch_images)
+ else:
+ break
+
+ # Pad the sample with the dummy image sample if there are not enough images
+ MAX_RETRY = self.sp_degree
+ num_retry = 0
+ while current_num_images < self.sp_degree and current_len < max_seq_length and num_retry <= MAX_RETRY:
+ current_num_images += dummy_image.size(0)
+ current_len += dummy_seqlen
+ current_num_samples += 1
+ current_position_ids = torch.cat((current_position_ids, dummy_position_ids), dim=0)
+ current_batch = torch.cat((current_batch, dummy_input_ids), dim=0)
+ current_label_batch = torch.cat((current_label_batch, dummy_labels), dim=0)
+ seqlens_in_batch.append(dummy_seqlen)
+ current_batch_images.extend(dummy_image)
+ # We pad from left side to ensure correct grad flow
+ # current_batch = torch.cat((dummy_input_ids, current_batch), dim=0)
+ # current_label_batch = torch.cat((dummy_labels, current_label_batch), dim=0)
+ # seqlens_in_batch.insert(0, dummy_seqlen)
+ # current_batch_images = torch.cat((dummy_image, current_batch_images), dim=0)
+ num_retry += 1
+
+ # Drop the samples that do not have enough images
+ if current_num_images < self.sp_degree:
+ print(f"Warning: Skipping one packed sample with {current_num_images} images")
+ seqlens_in_batch = seqlens_in_batch[:-current_num_samples]
+ continue
+
+ max_sample_len = max(max_sample_len, current_len)
+ batches.append(current_batch)
+ label_batches.append(current_label_batch)
+ position_ids.append(current_position_ids)
+ batch_images.append(current_batch_images)
+
+ try:
+ assert current_num_images == len(torch.where(current_batch == image_token_id)[0].tolist())
+ except AssertionError:
+ print(f"Error num_images on {self.sp_rank}", current_num_images)
+ print("current_batch", current_batch)
+ print(
+ f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:",
+ len(torch.where(current_batch == image_token_id)[0].tolist()),
+ )
+ print(f"Error len(current_batch_images) on {self.sp_rank}:", len(current_batch_images))
+ raise AssertionError
+
+ # Split for sequence parallelism
+ for i in range(len(batches)):
+ image_token_indices = torch.where(batches[i] == image_token_id)[0].tolist()
+ image_ids = torch.arange(0, len(image_token_indices), dtype=torch.int32)
+ batches[i] = extract_local_input_ids(
+ batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id
+ )
+ label_batches[i] = extract_local_input_ids(
+ label_batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id
+ )
+ batch_images[i] = torch.concat(
+ extract_local_from_list(batch_images[i], self.sp_rank, self.sp_degree), dim=0
+ )
+ H, W = batch_images[i].size(-2), batch_images[i].size(-1)
+ batch_images[i] = batch_images[i].reshape(-1, 3, W, H)
+ num_images = len(batch_images[i])
+
+ try:
+ assert num_images == len(torch.where(batches[i] == image_token_id)[0].tolist())
+ except AssertionError:
+ print(f"Error num_images on {self.sp_rank}", num_images)
+ print("batches[i]", batches[i])
+ print(
+ f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:",
+ len(torch.where(batches[i] == image_token_id)[0].tolist()),
+ )
+ print(f"Error batch_images[i] on {self.sp_rank}:", batch_images[i].shape)
+ raise AssertionError
+ position_ids[i] = extract_local_position_ids(
+ position_ids[i], image_token_indices, image_ids, self.sp_rank, self.sp_degree, NUM_TOKENS_PER_IMAGE - 1
+ )
+
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ batches, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(label_batches, batch_first=True, padding_value=IGNORE_INDEX)
+ seqlens_in_batch = [torch.tensor(x) for x in seqlens_in_batch]
+ seqlens_in_batch = torch.stack(seqlens_in_batch, dim=0)
+ seqlens_in_batch = seqlens_in_batch.flatten()
+ position_ids = torch.nn.utils.rnn.pad_sequence(position_ids, batch_first=True, padding_value=-1)
+
+ if batch_images:
+ batch_images = [torch.unbind(images) for images in batch_images]
+ flat_batch_images = [item for sublist in batch_images for item in sublist]
+ else:
+ flat_batch_images = None
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ # notice that we inject attention mask here
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ seqlens_in_batch=seqlens_in_batch,
+ media={"image": flat_batch_images},
+ media_config={"image": {}},
+ position_ids=position_ids,
+ )
+ return batch
+
+
+def make_supervised_data_module(
+ tokenizer: PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+) -> Dict:
+ """Make dataset and collator for supervised fine-tuning.
+ This function is originally implemented by the LLaVA team and
+ modified by Jason Lu, Haotian Tang and Ligeng Zhu."""
+ datasets_mixture.register_datasets_mixtures()
+
+ from .builder import build_dataset
+
+ train_dataset = build_dataset(data_args.data_mixture, data_args, training_args, tokenizer)
+ training_args.sample_lens = [len(d) for d in train_dataset.datasets]
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is None:
+ data_collator = DataCollator(tokenizer=tokenizer)
+ else:
+ sp_degree = training_args.seq_parallel_size
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
+ data_collator = DataCollatorForSupervisedDatasetSeqParallel(
+ tokenizer=tokenizer,
+ data_args=data_args,
+ training_args=training_args,
+ sp_degree=sp_degree,
+ sp_rank=sp_rank,
+ ring_degree=ring_degree,
+ ring_type=ring_type,
+ )
+
+ return dict(
+ train_dataset=train_dataset,
+ data_collator=data_collator,
+ )
diff --git a/llava/data/datasets_mixture.py b/llava/data/datasets_mixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..62a53fcf5fe69256e5cd4b005ab7f28f63d9fe7c
--- /dev/null
+++ b/llava/data/datasets_mixture.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import warnings
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Dataset:
+ dataset_name: str
+ dataset_type: str = field(default="torch")
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+ meta_path: str = field(default=None, metadata={"help": "Path to the meta data for webdataset."})
+ image_path: str = field(default=None, metadata={"help": "Path to the training image data."})
+ speech_path: str = field(default=None, metadata={"help": "Path to the training speech data."})
+ caption_choice: str = field(default=None, metadata={"help": "Path to the caption directory for recaption."})
+ description: str = field(
+ default=None,
+ metadata={
+ "help": "Detailed desciption of where the data is from, how it is labelled, intended use case and the size of the dataset."
+ },
+ )
+ test_script: str = None
+ maintainer: str = None
+ ############## ############## ############## ############## ############## ##############
+ caption_choice: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
+ caption_choice_2: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
+ start_idx: float = field(default=-1, metadata={"help": "Start index of the dataset."})
+ end_idx: float = field(default=-1, metadata={"help": "End index of the dataset."})
+
+
+DATASETS_LEGACY = {}
+
+
+def add_dataset(dataset):
+ if dataset.dataset_name in DATASETS_LEGACY:
+ # make sure the data_name is unique
+ warnings.warn(f"{dataset.dataset_name} already existed in DATASETS. Make sure the name is unique.")
+ assert "+" not in dataset.dataset_name, "Dataset name cannot include symbol '+'."
+ DATASETS_LEGACY.update({dataset.dataset_name: dataset})
+
+
+def register_datasets_mixtures():
+ ############## ############## ############## ############## ############## ##############
+ # Audio Datasets
+ ############## ############## ############## ############## ############## ##############
+
+ data_mixture_1 = Dataset(
+ dataset_name="data_mixture_1",
+ dataset_type="torch",
+ data_path="/path/to/your/data_mixture_1/train.json",
+ )
+ add_dataset(data_mixture_1)
+
+ data_mixture_2 = Dataset(
+ dataset_name="data_mixture_2",
+ dataset_type="torch",
+ data_path="/path/to/your/data_mixture_2/train.json",
+ )
+ add_dataset(data_mixture_2)
+ # Add more data mixtures below
\ No newline at end of file
diff --git a/llava/data/registry/datasets/audio_test.yaml b/llava/data/registry/datasets/audio_test.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..144c86b4bee8d3bbe67b32b7e7948e6f3c8e4bb9
--- /dev/null
+++ b/llava/data/registry/datasets/audio_test.yaml
@@ -0,0 +1,97 @@
+---
+Clotho-AQA-AQA:
+ _target_: llava.data.LLaVADataset
+ data_path: Clotho-AQA-AQA/test.json
+Music-AVQA-AQA_All:
+ _target_: llava.data.LLaVADataset
+ data_path: Music-AVQA-AQA_All/test.json
+CochlScene-SceneClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: CochlScene-SceneClassification/test.json
+NSynth-Source:
+ _target_: llava.data.LLaVADataset
+ data_path: NSynth-Source/test.json
+NSynth-Instrument:
+ _target_: llava.data.LLaVADataset
+ data_path: NSynth-Instrument/test.json
+FSD50k-EventClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: FSD50k-EventClassification/test.json
+Clotho-v2-AudioCaptioning:
+ _target_: llava.data.LLaVADataset
+ data_path: Clotho-v2-AudioCaptioning/test.json
+audiocaps-AudioCaptioning:
+ _target_: llava.data.LLaVADataset
+ data_path: audiocaps-AudioCaptioning/test.json
+ravdess-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: ravdess-EmotionClassification/val.json
+GTZAN-GenreClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: GTZAN-GenreClassification/test.json
+UrbanSound8K-EventClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: UrbanSound8K-EventClassification/train.json
+Medley-solos-DB-InstrClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: Medley-solos-DB-InstrClassification/test.json
+ESC50-EventClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: ESC50-EventClassification/train.json
+CREMA-D-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: CREMA-D-EmotionClassification/test.json
+IEMOCAP-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: IEMOCAP-EmotionClassification/test.json
+MELD-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: MELD-EmotionClassification/test.json
+MELD-SentimentClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: MELD-SentimentClassification/test.json
+MMAU:
+ _target_: llava.data.LLaVADataset
+ data_path: MMAU/test.json
+MMAU-mini:
+ _target_: llava.data.LLaVADataset
+ data_path: MMAU/test-mini.json
+AudioEntailmentQA:
+ _target_: llava.data.LLaVADataset
+ data_path: AudioEntailmentQA/test.json
+SPGI-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: SPGI-ASR/val.json
+SWBD-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: SWBD-ASR/val.json
+LibriSpeech-ASR-clean:
+ _target_: llava.data.LLaVADataset
+ data_path: LibriSpeech-ASR/test_clean.json
+LibriSpeech-ASR-other:
+ _target_: llava.data.LLaVADataset
+ data_path: LibriSpeech-ASR/test_other.json
+VoxPopuli-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: VoxPopuli-ASR/test.json
+Europarl-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: Europarl-ASR/test.json
+CV-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: CV-ASR/test.json
+GigaSpeech-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: GigaSpeech-ASR/test.json
+CompA-R-AQA:
+ _target_: llava.data.LLaVADataset
+ data_path: CompA-R-AQA/test.json
+MuschoMusicQA:
+ _target_: llava.data.LLaVADataset
+ data_path: MuschoMusicQA/test.json
+CMM:
+ _target_: llava.data.LLaVADataset
+ data_path: CMM/test.json
+AIR-Bench:
+ _target_: llava.data.LLaVADataset
+ data_path: AIR-Bench/test.json
diff --git a/llava/data/registry/datasets/default.yaml b/llava/data/registry/datasets/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1dc18dfebe7b092e029ba6661c41925d503477ae
--- /dev/null
+++ b/llava/data/registry/datasets/default.yaml
@@ -0,0 +1,5 @@
+---
+dummy:
+ _target_: llava.data.DummyDataset
+ num_instances: 10000
+ comments: dummy dataset for testing
diff --git a/llava/data/registry/mixtures.yaml b/llava/data/registry/mixtures.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..70f69accb593595bd978de144a8b571b6c24b3b6
--- /dev/null
+++ b/llava/data/registry/mixtures.yaml
@@ -0,0 +1,78 @@
+---
+audio_speech_all:
+ -CV-ASR_1
+ -MELD-EmotionClassification+
+ -BBCSoundEffects-AudioDescription
+ -SWBD-ASR_1
+ -WavCaps-SoundBible-AudioCaptioning
+ -AudioSet-Speech-Audio-QA
+ -SONYC-UST-EventClassification
+ -VoxPopuli-ASR_1
+ -FSD50k-EventClassification
+ -SalmonnQA
+ -emov-db-EmotionClassification
+ -LLARK_MagnaTagATune-mir+tess-EmotionClassification
+ -Europarl-ASR_1
+ -jl-corpus-EmotionClassification
+ -Ego-10-AudioCaptioning
+ -SPGI-ASR_1
+ -CREMA-D-EmotionClassification
+ -MusicBenchQA
+ -WavCaps-BBC_Sound_Effects-AudioCaptioning
+ -NSynth-Instrument
+ -SpokenSquadQA
+ -NSynth-MIR
+ -AudioEntailmentQA
+ -GigaSpeech-ASR_1
+ -WavCaps-AudioSet_SL-AudioCaptioning
+ -NonSpeech7k-EventClassification
+ -chime-home-EventClassification
+ -MusicCaps-AudioCaptioning
+ -LP-MusicCaps-MSD-AudioCaptioning
+ -Ego-30-AudioCaptioning
+ -NSynth-Source+Clotho-v2-AudioCaptioning
+ -LP-MusicCaps-MC-AudioCaptioning
+ -Clotho-AQA-EventClassification
+ -WavCaps-FreeSound-AudioCaptioning
+ -LLARK_MagnaTagATune-reasoning
+ -AudioSet-Temporal-Speech-Audio-QA
+ -TUT-EventClassification
+ -ESC50-EventClassification
+ -WavText5K-Tagging
+ -MELD-SentimentClassification
+ -Music-AVQA-AQA_All
+ -Music-AVQA-AVQA_All
+ -MACS-AudioCaptioning
+ -Medley-solos-DB-InstrClassification
+ -AudioSet-EventClassification
+ -OMGEmotion-EmotionClassification
+ -FMA-GenreClassification
+ -Epidemic_sound-AudioCaptioning
+ -CochlScene-SceneClassification
+ -LLARK_FMA-reasoning
+ -ravdess-EmotionClassification
+ -CompA-R-AQA
+ -MU-LLAMA-AQA
+ -musdbhq-InstrClassification
+ -UrbanSound8K-EventClassification
+ -audiocaps-AudioCaptioning
+ -VocalSound-VocalClassification
+ -CLAP_freesound-AudioCaptioning
+ -MMAUQA
+ -SongDescriber-AudioCaptioning
+ -HeySQuADQA
+ -Mira-AudioCaptioning
+ -Clotho-AQA-AQA
+ -LibriSpeech-ASR_1
+ -IEMOCAP-EmotionClassification
+ -AudioSetFullwoAudioMusicCaps-EventClassification
+ -MSP-PODCAST-Publish-1.9-EmotionClassification
+ -OpenAQA-AQA
+ -SoundDescs-AudioDescription
+ -LibriSQA
+ -LLARK_FMA-mir
+ -LP-MusicCaps-MTT-AudioCaptioning
+ -GTZAN-GenreClassification
+ -musdbhq-captioning
+ -YesNoQA
+
diff --git a/llava/entry.py b/llava/entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..44348e95eea2598a030f6953ff2098e133bed68f
--- /dev/null
+++ b/llava/entry.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import typing
+from typing import List, Optional
+
+if typing.TYPE_CHECKING:
+ from transformers import PreTrainedModel
+else:
+ PreTrainedModel = None
+
+__all__ = ["load"]
+
+
+def load(
+ model_path: str,
+ model_base: Optional[str] = None,
+ devices: Optional[List[int]] = None,
+ **kwargs,
+) -> PreTrainedModel:
+ import torch
+
+ from llava.conversation import auto_set_conversation_mode
+ from llava.mm_utils import get_model_name_from_path
+ from llava.model.builder import load_pretrained_model
+
+ auto_set_conversation_mode(model_path)
+
+ model_name = get_model_name_from_path(model_path)
+ model_path = os.path.expanduser(model_path)
+ if os.path.exists(os.path.join(model_path, "model")):
+ model_path = os.path.join(model_path, "model")
+
+ # Set `max_memory` to constrain which GPUs to use
+ if devices is not None:
+ assert "max_memory" not in kwargs, "`max_memory` should not be set when `devices` is set"
+ kwargs.update(max_memory={device: torch.cuda.get_device_properties(device).total_memory for device in devices})
+
+ model = load_pretrained_model(model_path, model_name, model_base, **kwargs)[1]
+ return model
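
`llava.load` is the single public entry point: it resolves the conversation mode from the checkpoint name, expands the path (preferring a nested `model/` directory when present), and translates a `devices` list into a `max_memory` map so that weights are only placed on the selected GPUs. A hypothetical usage, with a placeholder checkpoint path:

```python
import llava

# Restrict weight placement to GPUs 0 and 1. `max_memory` is derived from `devices`,
# so passing both would trip the assertion inside `load`.
model = llava.load("/path/to/af3/stage3_checkpoint", devices=[0, 1])
```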
diff --git a/llava/eval/__init__.py b/llava/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d807fb5a2b32ea471a6bfec0d57694f57f91e377
--- /dev/null
+++ b/llava/eval/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+
+from llava.utils import io
+
+__all__ = ["EVAL_ROOT", "TASKS"]
+
+
+EVAL_ROOT = "scripts/eval"
+TASKS = io.load(os.path.join(os.path.dirname(__file__), "registry_audio.yaml"))
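
`TASKS` is simply the parsed `registry_audio.yaml` (which appears later in this diff), so a driver script can enumerate task names and filter by tag. A small sketch, assuming `io.load` returns the YAML mapping as a plain dict:

```python
from llava.eval import EVAL_ROOT, TASKS

# Collect the tasks tagged as locally evaluated benchmarks.
local_tasks = [name for name, meta in TASKS.items() if "local" in meta.get("tags", [])]
print(f"{len(local_tasks)} local tasks; evaluation scripts live under {EVAL_ROOT}/")
```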
diff --git a/llava/eval/eval_audio_bench.py b/llava/eval/eval_audio_bench.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3017aef3ba10519f507e4d878d834c7ab3d7dbf
--- /dev/null
+++ b/llava/eval/eval_audio_bench.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import csv
+import itertools
+import json
+import os
+
+import torch
+from datasets import load_dataset
+from tqdm import tqdm
+
+import llava
+from llava import conversation as conversation_lib
+from llava.data.builder import DATASETS
+from llava.eval.mmmu_utils.eval_utils import parse_choice
+from llava.utils import distributed as dist
+from llava.utils import io
+from llava.utils.logging import logger
+
+
+def load_existing_ids(output_file):
+ if not os.path.exists(output_file):
+ return set(), []
+ try:
+ with open(output_file, "r") as f:
+ lines = f.readlines()
+ outputs = [json.loads(line) for line in lines]
+ processed_ids = {item["id"] for item in outputs}
+ return processed_ids, outputs
+ except Exception as e:
+ print(f"Error loading existing outputs: {e}")
+ return set(), []
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default=None)
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--task", type=str, default=None)
+ parser.add_argument("--conv-mode", type=str, default="auto")
+ parser.add_argument("--generation-config", type=json.loads)
+ parser.add_argument("--output-dir", type=str, default=None)
+ args = parser.parse_args()
+
+ # Set up distributed environment
+ dist.init()
+ devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size())
+ torch.cuda.set_device(devices[0])
+
+ # Load the stage 3 model
+ model = llava.load(args.model_base, model_base=None, devices=devices)
+ # Uncomment the block below to load the stage 3.5 model on top of stage 3 for thinking mode and long audio mode (requires `from peft import PeftModel`)
+ # model = PeftModel.from_pretrained(
+ # model,
+ # args.model_path,
+ # device_map="auto",
+ # torch_dtype=torch.float16,
+ # )
+ # Set up generation config
+ generation_config = model.default_generation_config
+ if args.generation_config is not None:
+ generation_config.update(**args.generation_config)
+
+ # Load data and chunk it
+ json_file = DATASETS[args.task]["data_path"]
+ instances = io.load(json_file)
+ instances = instances[dist.rank() :: dist.size()]
+
+ output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl")
+ processed_ids, outputs = load_existing_ids(output_path)
+
+ count = len(outputs)
+ # Run inference
+ new_outputs = []
+ for instance in tqdm(instances, disable=not dist.is_main()):
+ uuid = instance["id"]
+ sound_path = instance["sound"]
+
+ if sound_path in processed_ids:
+ continue # Skip if already processed
+ sound = llava.Sound(sound_path)
+ conversations = instance["conversations"]
+ question = conversations[0]["value"]
+
+ response = model.generate_content([sound, question], generation_config=generation_config)
+
+ print("response", response)
+
+ output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response}
+ new_outputs.append(output)
+ count += 1
+ if count % 20 == 0:
+ # Gather and save outputs
+ if dist.size() > 1:
+ outputs_new = dist.gather(new_outputs, dst=0)
+ if dist.is_main():
+ outputs_new = list(itertools.chain(*outputs_new))
+ final_outputs = outputs + outputs_new
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
+ else:
+ final_outputs = outputs + new_outputs
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
+ if dist.size() > 1:
+ new_outputs = dist.gather(new_outputs, dst=0)
+ if not dist.is_main():
+ return
+ new_outputs = list(itertools.chain(*new_outputs))
+ final_outputs = outputs + new_outputs
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
+
+if __name__ == "__main__":
+ main()
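
The evaluation loop shards the benchmark round-robin across ranks (`instances[dist.rank()::dist.size()]`), skips instances already present in the output JSONL, and periodically gathers partial results to rank 0. The sharding itself is plain Python slicing; a self-contained illustration:

```python
# Each rank takes every `world_size`-th instance, so the union of the shards
# covers the dataset exactly once with no overlap.
instances = list(range(10))          # stand-in for io.load(json_file)
world_size = 4
shards = [instances[rank::world_size] for rank in range(world_size)]
assert sorted(sum(shards, [])) == instances
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```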
diff --git a/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62c020ee5b25a39cc6679b0f7f28235d1db3120c
Binary files /dev/null and b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc differ
diff --git a/llava/eval/mmmu_utils/eval_utils.py b/llava/eval/mmmu_utils/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c3cc6bc917ffa3f5af7ed7c2ec601ef8dbfbc1
--- /dev/null
+++ b/llava/eval/mmmu_utils/eval_utils.py
@@ -0,0 +1,61 @@
+# This file is originated from the official MMMU codebase:
+# https://github.com/MMMU-Benchmark/MMMU
+
+import random
+
+import numpy as np
+
+
+def parse_choice(response, all_choices, index2ans=None):
+ """
+ Parse the prediction from the generated response.
+ Return the predicted index e.g., A, B, C, D.
+ """
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
+ response = response.strip(char)
+ response = " " + response + " " # add space to avoid partial match
+
+ index_ans = True
+ ans_with_brack = False
+ candidates = []
+ for choice in all_choices: # e.g., (A) (B) (C) (D)
+ if f"({choice})" in response:
+ candidates.append(choice)
+ ans_with_brack = True
+
+ if len(candidates) == 0:
+ for choice in all_choices: # e.g., A B C D
+ if f" {choice} " in response:
+ candidates.append(choice)
+
+ # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
+ if len(candidates) == 0 and len(response.split()) > 5 and index2ans is not None:
+ for index, ans in index2ans.items():
+ if ans.lower() in response.lower():
+ candidates.append(index)
+ index_ans = False # it's content ans.
+
+ if len(candidates) == 0: # still not get answer, randomly choose one.
+ pred_index = random.choice(all_choices)
+ elif len(candidates) > 1:
+ start_indexes = []
+ if index_ans:
+ if ans_with_brack:
+ for can in candidates:
+ index = response.rfind(f"({can})")
+ start_indexes.append(index) # -1 will be ignored anyway
+ # start_indexes = [generated_response.index(f'({can})') for can in candidates]
+ else:
+ for can in candidates:
+ index = response.rfind(f" {can} ")
+ start_indexes.append(index)
+ else:
+ for can in candidates:
+ index = response.lower().rfind(index2ans[can].lower())
+ start_indexes.append(index)
+ # get the last one
+ pred_index = candidates[np.argmax(start_indexes)]
+ else: # if only one candidate, use it.
+ pred_index = candidates[0]
+
+ return pred_index
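
`parse_choice` tries progressively looser matches: bracketed letters like `(B)`, then bare letters surrounded by spaces, then (for longer free-form answers) the answer text from `index2ans`; ties are broken by the last occurrence, and unresolved cases fall back to a random choice. A quick illustration with made-up choices:

```python
from llava.eval.mmmu_utils.eval_utils import parse_choice

all_choices = ["A", "B", "C", "D"]
index2ans = {"A": "piano", "B": "violin", "C": "drums", "D": "flute"}

print(parse_choice("The answer is (B).", all_choices, index2ans))                      # B  (bracketed letter)
print(parse_choice("I would go with C here", all_choices, index2ans))                  # C  (bare letter)
print(parse_choice("It sounds like a flute playing softly", all_choices, index2ans))   # D  (answer text match)
```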
diff --git a/llava/eval/registry_audio.yaml b/llava/eval/registry_audio.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8de55c73c3d12785044fb0cbd09c847eb548a514
--- /dev/null
+++ b/llava/eval/registry_audio.yaml
@@ -0,0 +1,93 @@
+Clotho-AQA-AQA:
+ tags:
+ - local
+Music-AVQA-AQA_All:
+ tags:
+ - local
+CochlScene-SceneClassification:
+ tags:
+ - local
+NSynth-Source:
+ tags:
+ - local
+NSynth-Instrument:
+ tags:
+ - local
+FSD50k-EventClassification:
+ tags:
+ - local
+Clotho-v2-AudioCaptioning:
+ tags:
+ - local
+audiocaps-AudioCaptioning:
+ tags:
+ - local
+ravdess-EmotionClassification:
+ tags:
+ - local
+GTZAN-GenreClassification:
+ tags:
+ - local
+UrbanSound8K-EventClassification:
+ tags:
+ - local
+Medley-solos-DB-InstrClassification:
+ tags:
+ - local
+ESC50-EventClassification:
+ tags:
+ - local
+CREMA-D-EmotionClassification:
+ tags:
+ - local
+IEMOCAP-EmotionClassification:
+ tags:
+ - local
+MELD-EmotionClassification:
+ tags:
+ - local
+MELD-SentimentClassification:
+ tags:
+ - local
+MMAU:
+ tags:
+ - local
+AudioEntailmentQA:
+ tags:
+ - local
+SPGI-ASR:
+ tags:
+ - local
+SWBD-ASR:
+ tags:
+ - local
+LibriSpeech-ASR-clean:
+ tags:
+ - local
+LibriSpeech-ASR-other:
+ tags:
+ - local
+VoxPopuli-ASR:
+ tags:
+ - local
+Europarl-ASR:
+ tags:
+ - local
+CV-ASR:
+ tags:
+ - local
+GigaSpeech-ASR:
+ tags:
+ - local
+CompA-R-AQA:
+ tags:
+ - local
+MuschoMusicQA:
+ tags:
+ - local
+CMM:
+ tags:
+ - local
+AIR-Bench:
+ tags:
+ - local
diff --git a/llava/media.py b/llava/media.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bcc0aa44034212a4649b9ef256eb5d995e48380
--- /dev/null
+++ b/llava/media.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+__all__ = ["Media", "File", "Image", "Video", "Speech", "Sound"]
+
+
+class Media:
+ pass
+
+
+class File(Media):
+ def __init__(self, path: str) -> None:
+ self.path = path
+
+
+class Image(File):
+ pass
+
+
+class Video(File):
+ pass
+
+
+class Speech(File):
+ pass
+
+class Sound(File):
+ pass
\ No newline at end of file
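
The media classes are deliberately thin: each one just records a file path, and the model's `generate_content` dispatches on the wrapper type (as in `eval_audio_bench.py` above, which wraps audio files in `Sound`). A usage sketch with a placeholder checkpoint path, assuming the default generation settings are acceptable:

```python
import llava

model = llava.load("/path/to/af3/stage3_checkpoint")   # placeholder path
sound = llava.Sound("static/audio/audio2.wav")
prompt = "Describe the audio in detail."
response = model.generate_content([sound, prompt])
print(response)
```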
diff --git a/llava/mm_utils.py b/llava/mm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4a23d338f9a4440aeaa6f44f7f5fb0d89093fe
--- /dev/null
+++ b/llava/mm_utils.py
@@ -0,0 +1,641 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
+
+import base64
+import os
+import tempfile
+from io import BytesIO
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import StoppingCriteria
+
+from pydub import AudioSegment
+from torchvision import transforms
+import soundfile as sf
+from librosa import resample as librosa_resample
+import whisper
+import random
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
+DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
+def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
+ import cv2
+
+ if fps is None or frame_count is None:
+ # if either fps or frame_count is missing, recompute both from the capture
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+ if fps == 0 or frame_count == 0:
+ print(f"Video file not found. return empty images. {video_file_name}")
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * num_frames, 0, [0.]
+
+ duration = frame_count / fps
+ frame_interval = frame_count // num_frames
+ if frame_interval == 0 and frame_count <= 1:
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * num_frames, 0, [0.]
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
+
+ images = []
+ count = 0
+ success = True
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
+ frame_times = [frame / fps for frame in frame_indices]
+ while success:
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
+ if frame_count >= num_frames:
+ success, frame = vidcap.read()
+ if count in frame_indices:
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except BaseException:
+ continue
+ if len(images) >= num_frames:
+ return images, num_frames, frame_times
+ count += 1
+ else:
+ # Left padding frames if the video is not long enough
+ success, frame = vidcap.read()
+ if success:
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except BaseException:
+ continue
+ count += 1
+ else:
+ break
+ if len(images) == 0:
+ raise ValueError("Did not find enough frames in the video. return empty image.")
+
+ return images, len(images), frame_times
+
+
+def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
+ """
+ num_frames is the max number of frames the model can support.
+ frame_count is the number of frames in the input video.
+ max_fps is the max FPS the model can support.
+ fps is the fps of the input video.
+ """
+
+ import random
+
+ import cv2
+
+ if fps is None or frame_count is None:
+ # if either fps or frame_count is missing, recompute both from the capture
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ if fps == 0 or frame_count == 0:
+ print(f"Video file not found. return empty images. {video_file_name}")
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+
+ duration = frame_count / fps
+ # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
+ # If the video is too long (longer than max_fps and num_frames can support),
+ # we will use lower fps to sample frames.
+ if duration >= num_frames / max_fps:
+ frame_interval = frame_count // num_frames
+
+ # If the video is too short, we will skip the video if there is only one frame.
+ if frame_interval == 0 and frame_count <= 1:
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+
+ images = []
+ count = 0
+ success = True
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
+ frame_times = [frame / fps for frame in frame_indices]
+ while success:
+ if frame_count >= num_frames:
+ # success, frame = vidcap.read()
+ if count in frame_indices:
+ success, frame = vidcap.read()
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except:
+ # print("Failed to read frame:", count)
+ continue
+ if len(images) >= num_frames:
+ return images, num_frames, frame_times
+ else:
+ success = vidcap.grab()
+ count += 1
+ else:
+ # Left padding frames if the video is not long enough
+ success, frame = vidcap.read()
+ if success:
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except:
+ # print("Failed to read frame:", count)
+ continue
+ count += 1
+ else:
+ break
+ else:
+ frames_required = int(duration * max_fps)
+ frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
+ if frames_required == 0:
+ print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+ elif frames_required == 1:
+ frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
+ images = []
+ count = 0
+ looked = 0
+ success = True
+
+ while success:
+ success, frame = vidcap.read()
+ if success and (looked in frame_indices):
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except:
+ continue
+ count += 1
+ looked += 1
+ frame_times = [frame / fps for frame in frame_indices]
+
+ if len(images) == 0:
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+ else:
+ return images, len(images), frame_times
+
+
+def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
+ """
+ Extract frames from a video using OpenCV.
+
+ Args:
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
+ frames (int): Number of frames to extract from the video.
+ fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
+
+ Returns:
+ list: List of PIL Images extracted from the video.
+
+ Raises:
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
+ """
+ import cv2
+ if isinstance(vpath_or_bytesio, str):
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
+ if max_fps > 0.0:
+ return get_frame_from_vcap_with_fps(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
+ )
+ return get_frame_from_vcap(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
+ )
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
+ # assuming mp4
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
+ temp_video.write(vpath_or_bytesio.read())
+ temp_video_name = temp_video.name
+ vidcap = cv2.VideoCapture(temp_video_name)
+ if max_fps > 0.0:
+ return get_frame_from_vcap_with_fps(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
+ )
+ return get_frame_from_vcap(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
+ )
+ else:
+ raise NotImplementedError(type(vpath_or_bytesio))
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ """
+ Expand the given PIL image to a square shape by adding padding.
+
+ Parameters:
+ - pil_img: The PIL image to be expanded.
+ - background_color: The color of the padding to be added.
+
+ Returns:
+ - The expanded PIL image.
+
+ If the image is already square, it is returned as is.
+ If the image is wider than it is tall, padding is added to the top and bottom.
+ If the image is taller than it is wide, padding is added to the left and right.
+ """
+ width, height = pil_img.size
+ if pil_img.mode == "L":
+ background_color = background_color[0]
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float("inf")
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+
+def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = {
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ }
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
+
+
+def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+ min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # use at least as many tiles as the largest scale does
+
+ processed_images = []
+
+ ##########################################################################################
+ ############# Add tiles for all but the last scale using a fixed square ratio #############
+ ##########################################################################################
+
+ for scale in s2_scales[:-1]:
+ target_width = image_size * (scale // s2_scales[0])
+ target_height = image_size * (scale // s2_scales[0])
+ blocks = (scale // s2_scales[0]) ** 2
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+
+ ##########################################################################################
+ ################ Add tiles for the last scale using dynamic aspect ratio #################
+ ##########################################################################################
+
+ # calculate the existing image aspect ratio
+ target_ratios = {
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ }
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+
+ return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
+
+
+
+def dynamic_s2_process_images_and_prompt(images, data_args, image_folder=None):
+ idx = 0
+ all_images = []
+ all_block_size = []
+ for img in images:
+ processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
+ all_images.append(processed_images)
+ all_block_size.append(block_size)
+ idx += 2
+ if all_images:
+ all_images = torch.cat(all_images)
+ else:
+ all_images = None
+ return all_images, all_block_size
+
+
+def process_image(
+ image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
+):
+ processor = data_args.image_processor
+ if isinstance(image_file, str):
+ if image_folder is not None:
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
+ else:
+ image = Image.open(image_file).convert("RGB")
+ else:
+ # image is stored in bytearray
+ image = image_file
+ image = image.convert("RGB")
+ if hasattr(data_args.image_processor, "crop_size"):
+ # CLIP vision tower
+ crop_size = data_args.image_processor.crop_size
+ else:
+ # SIGLIP vision tower
+ assert hasattr(data_args.image_processor, "size")
+ crop_size = data_args.image_processor.size
+ if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
+ assert crop_size["height"] == crop_size["width"]
+ images, block_size = dynamic_s2_preprocess(
+ image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
+ )
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
+ return torch.stack(images), block_size
+ if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
+ assert crop_size["height"] == crop_size["width"]
+ if max_tiles is not None:
+ max_num = max_tiles
+ else:
+ max_num = data_args.max_tiles
+ images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
+ return torch.stack(images)
+
+ if data_args.image_aspect_ratio == "resize":
+ image = image.resize((crop_size["width"], crop_size["height"]))
+ if data_args.image_aspect_ratio == "pad":
+
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ else:
+ # Using default behavior of the vision encoder
+ # For CLIP, default is central crop
+ # For Radio, default is central crop
+ # For Siglip, default is resize
+ # For InternVIT, default is resize
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ return image
+
+def get_num_windows(T, sr, max_num_window=5):
+
+ window_length = int(30.0 * sr)
+ window_overlap = int(0.0 * sr)
+ # at most max_num_window windows of 30 s each, with no overlap
+
+ num_windows = 1
+ if T <= window_length:
+ num_windows = 1
+ full_length = window_length
+ elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
+ num_windows = max_num_window
+ full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
+ else:
+ num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
+ full_length = num_windows * window_length - (num_windows - 1) * window_overlap
+
+ return num_windows, full_length
+
+def load_audio(file_path, target_sr=16000, duration=30.0, start=0.0):
+ if file_path.endswith('.mp3'):
+ audio = AudioSegment.from_file(file_path)
+ if len(audio) > (start + duration) * 1000:
+ audio = audio[start * 1000:(start + duration) * 1000]
+
+ if audio.frame_rate != target_sr:
+ audio = audio.set_frame_rate(target_sr)
+
+ if audio.channels > 1:
+ audio = audio.set_channels(1)
+
+ data = np.array(audio.get_array_of_samples())
+ if audio.sample_width == 2:
+ data = data.astype(np.float32) / np.iinfo(np.int16).max
+ elif audio.sample_width == 4:
+ data = data.astype(np.float32) / np.iinfo(np.int32).max
+ else:
+ raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
+
+ else:
+ with sf.SoundFile(file_path) as audio:
+ original_sr = audio.samplerate
+ channels = audio.channels
+
+ max_frames = int((start + duration) * original_sr)
+
+ audio.seek(int(start * original_sr))
+ frames_to_read = min(max_frames, len(audio))
+ data = audio.read(frames_to_read)
+
+ if data.max() > 1 or data.min() < -1:
+ data = data / max(abs(data.max()), abs(data.min()))
+
+ if original_sr != target_sr:
+ if channels == 1:
+ data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
+ else:
+ data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
+ else:
+ if channels != 1:
+ data = data.T[0]
+
+ if data.min() >= 0:
+ data = 2 * data / abs(data.max()) - 1.0
+ else:
+ data = data / max(abs(data.max()), abs(data.min()))
+
+ assert len(data.shape) == 1, data.shape
+ return data
+
+def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
+ model_cfg.image_processor = image_processor
+ new_images = [
+ process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
+ for image in images
+ ]
+
+ if all(x.shape == new_images[0].shape for x in new_images):
+ if len(new_images[0].shape) == 4:
+ new_images = torch.cat(new_images, dim=0)
+ elif len(new_images[0].shape) == 3:
+ new_images = torch.stack(new_images, dim=0)
+ else:
+ raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
+ else:
+ raise ValueError("The shape of images in new_images is different!")
+ return new_images
+
+def process_sounds(sounds):
+ sounds = torch.tensor(sounds)
+ return sounds
+
+def process_sound_masks(masks):
+ masks = torch.tensor(masks[0])
+ return masks
+
+def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
+
+def is_gemma_tokenizer(tokenizer):
+ return "gemma" in tokenizer.__class__.__name__.lower()
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith("checkpoint-"):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ self.max_keyword_len = 0
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ if len(cur_keyword_ids) > self.max_keyword_len:
+ self.max_keyword_len = len(cur_keyword_ids)
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ outputs = []
+ for i in range(output_ids.shape[0]):
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
+ return all(outputs)
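
Besides the image tiling helpers, this module handles audio windowing: `get_num_windows` splits a waveform into non-overlapping 30-second windows, capped at `max_num_window`, and also returns the padded length the windows jointly cover. A worked example at 16 kHz:

```python
from llava.mm_utils import get_num_windows

sr = 16000
T = int(100.0 * sr)                                   # a 100-second clip
num_windows, full_length = get_num_windows(T, sr, max_num_window=5)
print(num_windows, full_length / sr)                  # 4 windows spanning 120.0 s after padding
```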
diff --git a/llava/model/FloatPointQuantizeTorch.py b/llava/model/FloatPointQuantizeTorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b01c6f8ad008cdcd01c51a85399a194032df05b1
--- /dev/null
+++ b/llava/model/FloatPointQuantizeTorch.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import math
+
+import torch
+
+
+def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit
+ expo = torch.floor(torch.log2(x_abs))
+ expo = torch.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / torch.exp2(expo)
+
+ mant_int = torch.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ if stochastic:
+ noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5)
+ mant_frac.add_(noise)
+ mant_frac = torch.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * (2**expo) * mant_q
+ y = y.to(x)
+
+ return y
+
+
+def floatExM0_quantize_torch(x, e_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1)
+ expo = torch.log2(x_abs)
+ if stochastic:
+ noise = expo.new(expo.shape).uniform_(-0.5, 0.5)
+ expo.add_(noise)
+ log_bias = math.log2(4 / 3) - 1 / 2
+ expo.add_(torch.ones_like(expo) * log_bias)
+ expo = torch.clamp(expo, min=Elow - 1, max=Ehigh)
+ expo = torch.round(expo)
+
+ y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0
+ y = y.to(x)
+
+ return y
+
+
+def Dynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ zero_mask = y.abs() > 1.01 * 10 ** (1 - bit)
+ y = y * zero_mask
+ y = y.to(x)
+ return y
+
+
+def ZeroDynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ y = y.to(x)
+ return y
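
These helpers perform simulated low-precision rounding: values are snapped to an ExMy grid (optionally with stochastic rounding) but remain stored in the original dtype, so they model precision loss rather than changing storage. A minimal sketch quantizing a tensor to an E4M3-style format:

```python
import torch

from llava.model.FloatPointQuantizeTorch import floatExMy_quantize_torch

x = torch.randn(4, dtype=torch.bfloat16)
x_q = floatExMy_quantize_torch(x, e_bit=4, m_bit=3, stochastic=False)
print(x)
print(x_q)
print((x - x_q).abs().max())   # quantization error introduced by the 4-bit/3-bit grid
```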
diff --git a/llava/model/FloatPointQuantizeTriton.py b/llava/model/FloatPointQuantizeTriton.py
new file mode 100644
index 0000000000000000000000000000000000000000..e619c3739ec968c92460b3d8d9efd2736fd07d44
--- /dev/null
+++ b/llava/model/FloatPointQuantizeTriton.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import math
+import struct
+
+import numpy as np
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+segment_size = 1024**3
+
+
+def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic):
+ x_ori_shape = x.shape
+ x = x.view(-1)
+
+ n_elements = x.numel()
+
+ if n_elements <= segment_size:
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ y = torch.empty_like(x)
+
+ if x.dtype in [torch.bfloat16, torch.float32]:
+ if stochastic:
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
+ else:
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
+ torch.cuda.synchronize()
+ else:
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
+ else: # Triton will break when x.numel > 2 * 1024 ** 3
+ num_segments = n_elements // segment_size + 1
+ split_size = [segment_size] * (num_segments - 1) + [n_elements - segment_size * (num_segments - 1)]
+ x_list = x.split(split_size)
+ y_list = []
+ del x
+
+ for x in x_list:
+ n_elements = x.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ y = torch.empty_like(x)
+
+ if x.dtype in [torch.bfloat16, torch.float32]:
+ if stochastic:
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
+ else:
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
+ torch.cuda.synchronize()
+ else:
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
+
+ y_list.append(y)
+ y = torch.concat(y_list)
+ del y_list
+
+ y = y.reshape(x_ori_shape)
+ return y
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_warps=4,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_quantize_kernel(
+ x_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ # mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_warps=4,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_stochastic_quantize_kernel(
+ x_ptr,
+ noise_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+ noise = tl.load(noise_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
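
The Triton path mirrors the PyTorch reference above but runs as a fused GPU kernel, splitting tensors larger than `segment_size` elements to stay under Triton's indexing limit. A sketch comparing the two implementations on a CUDA device (requires a GPU with Triton installed; tiny differences are possible on rounding ties, since `libdevice.round` and `torch.round` break ties differently):

```python
import torch

from llava.model.FloatPointQuantizeTorch import floatExMy_quantize_torch
from llava.model.FloatPointQuantizeTriton import floatExMy_quantize_triton

x = torch.randn(1 << 16, device="cuda", dtype=torch.float32)
y_kernel = floatExMy_quantize_triton(x, e_bit=4, m_bit=3, stochastic=False)
y_ref = floatExMy_quantize_torch(x, e_bit=4, m_bit=3, stochastic=False)
print((y_kernel - y_ref).abs().max())  # expected to be at or near zero
```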
diff --git a/llava/model/__init__.py b/llava/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e5494e136efbc55e4813a40184308c26fe5949b
--- /dev/null
+++ b/llava/model/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .language_model.llava_llama import LlavaLlamaConfig, LlavaLlamaModel
+
+# FP8 related comments, development in progress (PI: ligeng zhu, haochen xi)
+# NOTE: VLM + LLM
+# from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel
+# NOTE: Linear -> fp8, similar to transformer engine
+# from .language_model.qllama import QLlamaConfig, QLlamaForCausalLM, QLlamaModel
+# NOTE: Linear + Activation -> fp8, haochen's iclr version
+# from .language_model.qmemllama import QMemLlamaConfig, QMemLlamaForCausalLM, QMemLlamaModel
+"""
+TODO:
+ linear(weights):
+ simulated fp8: done
+ real fp8: in-progress (code already implemented)
+ activation:
+ simulated fp8: done
+ real fp8: in-progress (still coding)
+ optimizers:
+ current VILA: bf16
+ simulated fp8: done
+ real fp8 + fsdp (single node): done
+ real fp8 + fsdp (multiple node): in-progress
+1. linear fp8
+2. activation fp8
+3. fp8 inference example (load directly from an fp8 checkpoint and run the forward pass)
+4. bind fp8 related configs to QLlamaConfig {"coat_fp8_args": {}}
+"""
+from .language_model.fp8linearqwen2 import FP8LinearQwen2Config, FP8LinearQwen2Model
+from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel
diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..767d0612da9a1fd77058e3d004d16e45d572e3ca
--- /dev/null
+++ b/llava/model/apply_delta.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+"""
+Usage:
+python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llava import LlavaLlamaForCausalLM
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading delta")
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+ print("Applying delta")
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+ if name not in base.state_dict():
+ assert name in [
+ "model.mm_projector.weight",
+ "model.mm_projector.bias",
+ ], f"{name} not in base model"
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data += base.state_dict()[name]
+ else:
+ assert name in [
+ "model.embed_tokens.weight",
+ "lm_head.weight",
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+ bparam = base.state_dict()[name]
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
+
+ print("Saving target model")
+ delta.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/llava/model/builder.py b/llava/model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c715ddf2eaeaa59d9807b6fdc3d31a71005067c
--- /dev/null
+++ b/llava/model/builder.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PretrainedConfig
+
+from llava.model import LlavaLlamaModel
+from llava.model.utils import is_mm_model
+
+
+def load_pretrained_model(
+ model_path,
+ model_name,
+ model_base=None,
+ load_8bit=False,
+ load_4bit=False,
+ device_map="auto",
+ device="cuda",
+ **kwargs,
+):
+ kwargs = {"device_map": device_map, **kwargs}
+
+ if device != "cuda":
+ kwargs["device_map"] = {"": device}
+
+ if load_8bit:
+ kwargs["load_in_8bit"] = True
+ elif load_4bit:
+ kwargs["load_in_4bit"] = True
+ kwargs["quantization_config"] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type="nf4",
+ )
+ else:
+ kwargs["torch_dtype"] = torch.float16
+ # kwargs["torch_dtype"] = torch.bfloat16
+
+ if is_mm_model(model_path):
+ # Load LLaVA model
+ ## TODO @yunhao: mind fixing lora
+ if "lora" in model_name.lower() and model_base is None:
+ warnings.warn(
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
+ )
+ if ("lora" in model_name.lower() or "dora" in model_name.lower()) and model_base is not None:
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ print(lora_cfg_pretrained)
+ print("Loading LLaVA from base model...")
+ config = AutoConfig.from_pretrained(model_base)
+ prepare_config_for_eval(config, kwargs)
+ model = LlavaLlamaModel.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
+ tokenizer = model.tokenizer
+ token_num, token_dim = model.llm.lm_head.out_features, model.llm.lm_head.in_features
+ if model.llm.lm_head.weight.shape[0] != token_num:
+ model.llm.lm_head.weight = torch.nn.Parameter(
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
+ )
+ model.llm.embed_tokens.weight = torch.nn.Parameter(
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
+ )
+
+ print("Loading additional LLaVA weights...")
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+ non_lora_trainables = torch.load(
+ os.path.join(model_path, "non_lora_trainables.bin"),
+ map_location="cpu",
+ )
+ else:
+ # this is probably from HF Hub
+ from huggingface_hub import hf_hub_download
+
+ def load_from_hf(repo_id, filename, subfolder=None):
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
+ return torch.load(cache_file, map_location="cpu")
+
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
+ non_lora_trainables = {
+ (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
+ non_lora_trainables = {
+ (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ from peft import PeftModel
+
+ print("Loading LoRA weights...")
+ model = PeftModel.from_pretrained(model, model_path)
+ print("Merging LoRA weights...")
+ model = model.merge_and_unload()
+ print("Model is loaded...")
+ else:
+ config = AutoConfig.from_pretrained(model_path)
+ config.resume_path = model_path
+ prepare_config_for_eval(config, kwargs)
+ model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
+ tokenizer = model.tokenizer
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print("Convert to FP16...")
+ model.to(torch.float16)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, legacy=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+ model.eval()
+ image_processor = None
+ if is_mm_model(model_path):
+ model.resize_token_embeddings(len(tokenizer))
+
+ if hasattr(model.llm.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, image_processor, context_len
+
+
+def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
+ try:
+ # compatible with deprecated config convention
+ if getattr(config, "vision_tower_cfg", None) is None:
+ config.vision_tower_cfg = config.mm_vision_tower
+ except AttributeError:
+ raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}")
+
+ config.model_dtype = kwargs.pop("torch_dtype").__str__()
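
`load_pretrained_model` returns a `(tokenizer, model, image_processor, context_len)` tuple; `llava.entry.load` above keeps only the model. A direct usage sketch with a placeholder checkpoint path:

```python
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "/path/to/audio-flamingo-checkpoint"   # placeholder path
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, get_model_name_from_path(model_path)
)
print(type(model).__name__, context_len)
```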
diff --git a/llava/model/coat/activation/__init__.py b/llava/model/coat/activation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69923298db60fa8275bf22895fd5c54cf102e9b9
--- /dev/null
+++ b/llava/model/coat/activation/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..04acd2fb73f707152839e1a97acb67e6f8df873b
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+import torch
+
+
+def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit
+ expo = torch.floor(torch.log2(x_abs))
+ expo = torch.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / torch.exp2(expo)
+
+ mant_int = torch.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ if stochastic:
+ noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5)
+ mant_frac.add_(noise)
+ mant_frac = torch.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * (2**expo) * mant_q
+ y = y.to(x)
+
+ return y
+
+
+def floatExM0_quantize_torch(x, e_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1)
+ expo = torch.log2(x_abs)
+ if stochastic:
+ noise = expo.new(expo.shape).uniform_(-0.5, 0.5)
+ expo.add_(noise)
+ log_bias = math.log2(4 / 3) - 1 / 2
+ expo.add_(torch.ones_like(expo) * log_bias)
+ expo = torch.clamp(expo, min=Elow - 1, max=Ehigh)
+ expo = torch.round(expo)
+
+ y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0
+ y = y.to(x)
+
+ return y
+
+
+def Dynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ zero_mask = y.abs() > 1.01 * 10 ** (1 - bit)
+ y = y * zero_mask
+ y = y.to(x)
+ return y
+
+
+def ZeroDynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ y = y.to(x)
+ return y
diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py
new file mode 100644
index 0000000000000000000000000000000000000000..db85e7fc42dc54539f54400ccff280402e126f0d
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+import struct
+
+import numpy as np
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+
+def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic):
+ n_elements = x.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ y = torch.zeros_like(x)
+
+ if x.dtype in [torch.bfloat16, torch.float32]:
+ if stochastic:
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
+ else:
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
+ else:
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
+
+ return y
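+
+# Illustrative call (requires a CUDA device, since the kernel is launched through Triton):
+#
+#   x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
+#   y = floatExMy_quantize_triton(x, e_bit=4, m_bit=3, stochastic=False)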
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_stages=1,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_quantize_kernel(
+ x_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ # mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_stages=1,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_stochastic_quantize_kernel(
+ x_ptr,
+ noise_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+ noise = tl.load(noise_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
diff --git a/llava/model/coat/activation/fake_quantization/quantize_function.py b/llava/model/coat/activation/fake_quantization/quantize_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b18d45603ecd3f6564a8dbac3810d8ce701fbb6
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/quantize_function.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+
+import torch
+
+from .FloatPointQuantizeTorch import *
+from .FloatPointQuantizeTriton import *
+
+
+def block_cut(input, row_block, column_block, pad_block=False):
+ # print(input.shape)
+ original_shape = input.shape
+ # input tensor shape is M * N
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+        input = torch.nn.functional.pad(
+            input, (0, col_pad, 0, row_pad), "constant", 0
+        )  # F.pad pads the last dim first: (left, right) for columns, then (top, bottom) for rows
+ M, N = input.shape[0], input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}"
+ assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}"
+ # print(input.shape)
+ input = (
+ input.reshape(row_num, row_block, column_num, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * column_num, row_block, column_block)
+ )
+ # print(input.shape)
+ return input
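+
+# Shape sketch (sizes are illustrative): a (64, 128) tensor cut into 16x16 blocks
+# becomes (row_num * column_num, row_block, column_block) = (32, 16, 16):
+#
+#   t = torch.randn(64, 128)
+#   blocks = block_cut(t, row_block=16, column_block=16)  # -> torch.Size([32, 16, 16])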
+
+
+def block_reshape(input, origin_input, row_block, column_block, pad_block=False):
+ if len(origin_input.shape) > 2:
+ flatten_input = origin_input.reshape(-1, origin_input.shape[2])
+ elif len(origin_input.shape) == 2:
+ flatten_input = origin_input
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut")
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+ pad_origin_input = torch.nn.functional.pad(origin_input, (0, col_pad, 0, row_pad), "constant", 0)
+ M, N = pad_origin_input.shape[0], pad_origin_input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ input = (
+ input.reshape(row_num, column_num, row_block, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * row_block, column_num * column_block)
+ )
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+ input = input[:M, :N]
+
+ if len(origin_input.shape) > 2:
+ input = input.reshape(origin_input.shape)
+ elif len(origin_input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block reshape")
+
+ return input
+
+
+def block_verify_int8(input, row_block, column_block, layer_type, necessary=True):
+ Binput = block_cut(input, row_block, column_block)
+ Binput = Binput.to(torch.float32)
+
+ for n in range(Binput.shape[0]):
+ unique_values = len(torch.unique(Binput[n, :, :]))
+ if unique_values > 256:
+ if necessary:
+ raise ValueError(f"{layer_type} contains more than 256 unique values.")
+ else:
+ return False
+ return True
+
+
+def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name):
+ Quant_fn = SymmQuantizer
+ return Quant_fn.apply(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name)
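+
+# Typical fake-quantization round trip (a sketch; the block sizes and the E4M3
+# format are illustrative, and FP8 formats dispatch to the Triton kernel, so `x`
+# is assumed to be a CUDA tensor):
+#
+#   blocks = block_cut(x, row_block=1, column_block=16)
+#   rq, q, scale = block_quant(blocks, symm=True, bits="E4M3", stochastic=False,
+#                              epsilon=1e-10, apply_quantize=True, layer_name="demo")
+#   x_fq = block_reshape(rq, x, row_block=1, column_block=16)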
+
+
+def extract_bit(string):
+ match = re.match(r"INT(\d+)", string) # INT8
+ if match:
+ return "integer", int(match.group(1)), None
+ match = re.match(r"E(\d+)M(\d+)", string) # E4M3 / E5M2
+ if match:
+ Ebit, Mbit = int(match.group(1)), int(match.group(2))
+ if Ebit == 1:
+ return "integer", Mbit + 1, None
+ if Mbit == 0:
+ return "floatExM0", int(match.group(1)), 0
+ return "floatExMy", int(match.group(1)), int(match.group(2))
+ match = re.match(r"DE(\d+)", string)
+ if match:
+ return "Dynamic", int(match.group(1)), None
+ match = re.match(r"ZeroD(\d+)", string)
+ if match:
+ return "ZeroDynamic", int(match.group(1)), None
+ raise ValueError(f"{string} data format is not supported")
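+
+# Examples of the format strings this parser accepts:
+#   extract_bit("INT8") -> ("integer", 8, None)
+#   extract_bit("E4M3") -> ("floatExMy", 4, 3)
+#   extract_bit("E5M2") -> ("floatExMy", 5, 2)
+#   extract_bit("E3M0") -> ("floatExM0", 3, 0)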
+
+
+class SymmQuantizer(torch.autograd.function.InplaceFunction):
+ @staticmethod
+ def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None):
+ with torch.no_grad():
+ absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon
+
+ if bits == "100" or not apply_quantize:
+ return input, input, torch.ones_like(absmax_per_block)
+ elif bits == "FP32":
+ return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block)
+ elif bits == "FP16":
+ return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block)
+ elif bits == "BF16":
+ return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block)
+ else:
+ QuantType, bit1, bit2 = extract_bit(bits)
+ if not symm:
+                    bit1 = bit1 + 1  # pretend to be asymmetric
+
+ if QuantType == "integer":
+ Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1
+ elif QuantType == "floatExMy":
+ Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * (
+ 2 ** (2 ** (bit1 - 1))
+ )
+ if bit1 == 4 and bit2 == 3: # E4M3
+ Qn, Qp = -448, 448
+ if bit1 == 5 and bit2 == 2: # E5M2
+ Qn, Qp = -57344, 57344
+ elif QuantType == "floatExM0":
+ Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1))
+ elif QuantType == "Dynamic":
+ Qn, Qp = -1, 1
+ elif QuantType == "ZeroDynamic":
+ Qn, Qp = -1, 1
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
+ scale_per_block = (2 * absmax_per_block) / (Qp - Qn)
+ scale_per_block = scale_per_block.to(input)
+
+ Qinput = input / scale_per_block
+
+ if QuantType == "integer":
+ if stochastic:
+ noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5)
+ Qinput.add_(noise)
+ Qinput.clamp_(Qn, Qp).round_()
+ elif QuantType == "floatExMy":
+ # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic)
+ Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic)
+ elif QuantType == "floatExM0":
+ Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic)
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
+
+ RQinput = Qinput * scale_per_block
+
+                if input.dtype != Qinput.dtype:
+                    raise TypeError(
+                        f"Dtype mismatch after quantization: input is {input.dtype}, "
+                        f"Qinput is {Qinput.dtype}, scale_per_block is {scale_per_block.dtype}"
+                    )
+ return RQinput, Qinput, scale_per_block
+
+ @staticmethod
+ def backward(ctx, grad_output):
+        return grad_output, None, None, None, None, None, None  # one gradient per forward input; the non-tensor args get None
diff --git a/llava/model/coat/activation/fake_quantization/utils.py b/llava/model/coat/activation/fake_quantization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc4ec2dae1ef7be1dc7c82c79439b7091d49feb6
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/utils.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def list_has_common_element(list1, list2):
+ set1 = set(list1)
+ set2 = set(list2)
+ return len(set1.intersection(set2)) > 0
+
+
+def calculate_scale_num(input, row_block, col_block):
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if col_block == -1:
+ col_block = N
+
+ return input.numel() / (row_block * col_block)
+
+
+def quant_get_local_rank() -> int:
+ return int(os.environ.get("LOCAL_RANK") or 0)
+
+
+def format_string_with_condition(
+ input_string,
+ condition_config,
+ symm,
+ bits,
+ blocksize_config,
+ input_pad=20,
+):
+ padded_string = input_string.ljust(input_pad)
+ output_string = padded_string
+
+ for k, v in condition_config.items():
+ if v:
+ output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6)
+ else:
+ output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6)
+
+ output_string = output_string + f"Symm {symm}".ljust(10)
+
+ for k, v in bits.items():
+ output_string = output_string + f"{k} bit".ljust(10) + v.ljust(10)
+ for k, v in blocksize_config.items():
+ output_string += f"{k}: {v}".ljust(15)
+
+ return output_string
+
+
+def print_warning(sentence):
+ print("*" * (len(sentence) + 4))
+ print(f"* {sentence} *")
+ print("*" * (len(sentence) + 4))
+
+
+def check_nan_inf(tensor, check_nan, check_inf):
+ if check_nan:
+ contain_nan = torch.isnan(tensor).any()
+ else:
+ contain_nan = False
+ if check_inf:
+ contain_inf = torch.isinf(tensor).any()
+ else:
+ contain_inf = False
+ return contain_nan, contain_inf
+
+
+def move_torch_to_numpy(tensor):
+ if tensor is None:
+ return None
+
+ if tensor.is_cuda:
+ tensor = tensor.cpu()
+ return tensor.detach().float().numpy()
+
+
+def flatten_to_1d(tensor):
+ if tensor is None:
+ return None
+
+ return tensor.reshape(-1)
diff --git a/llava/model/coat/activation/models/_fp8_quantization_config.py b/llava/model/coat/activation/models/_fp8_quantization_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..579df908203575abf8da9565995640c4ae459e3b
--- /dev/null
+++ b/llava/model/coat/activation/models/_fp8_quantization_config.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from dataclasses import dataclass
+
+from transformers import PretrainedConfig
+
+
+@dataclass
+class QuantizationConfig:
+ quantize_model: str = "false"
+ symm: bool = True
+ epsilon: float = 1e-10
+ fabit: str = "E4M3"
+ fwbit: str = "E4M3"
+ fobit: str = "E4M3"
+ babit: str = "E5M2"
+ bwbit: str = "E5M2"
+ bobit: str = "E5M2"
+ qchoice: str = "none"
+ group_size: int = -1
+ pad_to_multiple_of: int = 0
+ weight_memory_efficient: bool = True
+
+ # Legacy
+ row_blocksize: int = -1
+ col_blocksize: int = -1
+
+ def __init__(
+ self,
+ quantize_model: str = "false",
+ symm: bool = True,
+ epsilon: float = 1e-10,
+ fabit: str = "E4M3",
+ fwbit: str = "E4M3",
+ fobit: str = "E4M3",
+ babit: str = "E5M2",
+ bwbit: str = "E5M2",
+ bobit: str = "E5M2",
+ qchoice: str = "none",
+ group_size: int = -1,
+ pad_to_multiple_of: int = 0,
+ weight_memory_efficient: bool = True,
+ row_blocksize: int = -1,
+ col_blocksize: int = -1,
+ **kwargs,
+ ):
+ super().__init__()
+ self.quantize_model = quantize_model
+ self.symm = symm
+ self.epsilon = epsilon
+ self.fabit = fabit
+ self.fwbit = fwbit
+ self.fobit = fobit
+ self.babit = babit
+ self.bwbit = bwbit
+ self.bobit = bobit
+ self.qchoice = qchoice
+ self.group_size = group_size
+ self.pad_to_multiple_of = pad_to_multiple_of
+ self.weight_memory_efficient = weight_memory_efficient
+
+ self.row_blocksize = row_blocksize
+ self.col_blocksize = col_blocksize
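+
+# A construction sketch (the field values are illustrative): E4M3 for the forward
+# activations/weights/outputs and E5M2 for the backward gradients, with per-group
+# scaling over 16 elements.
+#
+#   qargs = QuantizationConfig(quantize_model="true", fabit="E4M3", fwbit="E4M3",
+#                              babit="E5M2", bwbit="E5M2", group_size=16)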
diff --git a/llava/model/coat/activation/models/_fp8_weightcache.py b/llava/model/coat/activation/models/_fp8_weightcache.py
new file mode 100644
index 0000000000000000000000000000000000000000..85b1c504c74b83299181020e1275133428f93e1d
--- /dev/null
+++ b/llava/model/coat/activation/models/_fp8_weightcache.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+
+from ..real_quantization import fp8_division_transpose
+
+
+class FP8CacheWeightModule(nn.Module):
+ def __init__(self, config, qargs, layer_id):
+ super().__init__()
+ self.config = config
+ self.qargs = qargs
+ self.layer_id = layer_id
+
+ def prepare_weight(self, weight, weight_name, is_first_microbatch):
+ if is_first_microbatch:
+ if self.qargs.weight_memory_efficient:
+ # print(f"{weight_name} uses first microbatch")
+ weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
+ weight, self.qargs.group_size, self.fwobits["fwbit"]
+ )
+ setattr(self, f"{weight_name}_fp8_scale", weight_s)
+ return weight_fp8, weight_fp8_t, weight_s
+ else:
+ # print(f"{weight_name} uses first microbatch")
+ weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
+ weight, self.qargs.group_size, self.fwobits["fwbit"]
+ )
+ setattr(self, f"{weight_name}_fp8", weight_fp8)
+ setattr(self, f"{weight_name}_fp8_t", weight_fp8_t)
+ setattr(self, f"{weight_name}_fp8_scale", weight_s)
+ return weight_fp8, weight_fp8_t, weight_s
+ else:
+ if self.qargs.weight_memory_efficient:
+ return getattr(self, f"{weight_name}_fp8_scale")
+ else:
+ return (
+ getattr(self, f"{weight_name}_fp8"),
+ getattr(self, f"{weight_name}_fp8_t"),
+ getattr(self, f"{weight_name}_fp8_scale"),
+ )
+
+ def forward(self, x):
+ pass
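+
+# How the CoatLlama modules call into this cache (a sketch; `self.fwobits` and the
+# FP8Manager flag are provided by the subclass and the training loop):
+#
+#   with torch.no_grad():
+#       w_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)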
diff --git a/llava/model/coat/activation/models/_fp8manager.py b/llava/model/coat/activation/models/_fp8manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..5270acf79adc7d17838959b0c679cba3b43c5bee
--- /dev/null
+++ b/llava/model/coat/activation/models/_fp8manager.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+class FP8Manager:
+ """Class to keep track of and manipulate the global
+ FP8 state at different stages of execution.
+ """
+
+ is_first_microbatch = False
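+
+# Intended usage (a sketch; the accumulation logic below is illustrative): the flag
+# is read by FP8CacheWeightModule.prepare_weight, so the training loop should set it
+# to True only for the first microbatch of each optimizer step, e.g.
+#
+#   FP8Manager.is_first_microbatch = (micro_step % gradient_accumulation_steps == 0)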
diff --git a/llava/model/coat/activation/models/coat_llama.py b/llava/model/coat/activation/models/coat_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b6763c027ac05d04c1b66b483cf886c0e1f7fab
--- /dev/null
+++ b/llava/model/coat/activation/models/coat_llama.py
@@ -0,0 +1,1479 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import os
+from fnmatch import fnmatch
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.generation import GenerationMixin
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaForCausalLM,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ _prepare_4d_causal_attention_mask_with_cache_position,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ is_torchdynamo_compiling,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+
+# FP8 related
+from ._fp8_quantization_config import QuantizationConfig
+from ._fp8_weightcache import FP8CacheWeightModule
+from ._fp8manager import FP8Manager
+
+logger = logging.get_logger(__name__)
+
+
+class CoatLlamaConfig(LlamaConfig):
+ model_type = "fp8_llama"
+
+
+class CoatLlamaBeforeAttentionResidual(FP8CacheWeightModule):
+ """
+    This is the pre-attention part of a transformer block: it contains (1) the residual shortcut, (2) RMSNorm, and (3) the Q/K/V linear projections.
+ """
+
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, qargs, layer_idx)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # Prepare
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
+ return _CoatLlamaBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.k_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.v_proj.weight,
+ None,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ return _CoatLlamaBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.k_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.v_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _CoatLlamaBeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size) # query states
+ fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size) # key states
+ fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size) # value states
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
+ else:
+ ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ return re_x, fc1_x, fc2_x, fc3_x
+
+ @staticmethod
+ def backward(ctx, fp_grad, query_g, key_g, value_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
+
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout.
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
+ query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
+ key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
+ value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+
+ fc1_g1, att_q_wg = fp8_linear_backward(
+ ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
+ )
+ fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
+ fc1_g3, att_v_wg = fp8_linear_backward(
+ ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
+ )
+
+ fc1_g = fc1_g1 + fc1_g2 + fc1_g3
+
+ # LayerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together, and prepare the input of the next layer.
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd.
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ att_q_wg,
+ None,
+ None,
+ None,
+ att_k_wg,
+ None,
+ None,
+ None,
+ att_v_wg,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class CoatLlamaAfterAttentionResidual(FP8CacheWeightModule):
+ """
+    This is the post-attention part of a transformer block: it contains (1) the residual addition and (2) the output (o_proj) linear projection.
+ """
+
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+ def forward(self, re_x, in_x):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
+
+ return _CoatLlamaAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ None,
+ None,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4, weight4_t, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatLlamaAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ weight4_t,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+            return re_x + self.o_proj(in_x), None, None
+
+
+class _CoatLlamaAfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs
+ ):
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+        # Attention output projection linear layer
+ if qargs.weight_memory_efficient:
+ assert weight4 is None # memory efficient
+ weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)
+        fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size)
+
+ # Add the activations together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ ctx.weight = weight4_origin, weight4_s
+ else:
+ ctx.weight = weight4_t, weight4_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight4_t, weight4_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # We do not save an extra flash_x to save the memory usage
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight4_t, weight4_s = fp8_division_transpose(
+ weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
+ )
+ fc4_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
+ )
+
+ return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class CoatLlamaMLPResidual(FP8CacheWeightModule):
+ """
+    This is the MLP part of a transformer block: it contains (1) the residual shortcut, (2) RMSNorm,
+    (3) the gate/up/down linear projections, and (4) the SiLU activation.
+ """
+
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.training = True
+
+ # below is only used when training = False
+ assert config.hidden_act == "silu", "We only support silu activation currently"
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch)
+
+ return _CoatLlamaMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.up_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.down_proj.weight,
+ None,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatLlamaMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("Need TODO")
+ og_x = re_x
+ re_x = self.ff_norm(re_x)
+ re_x = self.ff_proj(re_x)
+ re_x = self.act(re_x)
+ re_x = self.ff_out(re_x)
+ re_x = og_x + re_x
+ return re_x, None, None
+
+
+class _CoatLlamaMLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer.
+ if qargs.weight_memory_efficient:
+ assert weight1 is None and weight2 is None and weight3 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj
+ up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight3 is None: # memory efficient
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+ fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)
+
+ # Add the activation together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
+
+ if (
+ qargs.weight_memory_efficient
+ ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient.
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
+ else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint.
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ if qargs.weight_memory_efficient:
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+ fc3_g, weight3_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
+ )
+
+ # [MEM TEST]
+ del out_g, out_g_t, weight3_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
+ silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Silu activation
+ silu_g, silu_gs, silu_g_t = fp8_silu_backward(
+ gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+
+ # Gate Proj
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
+ )
+
+ fc_g = fc1_g + fc2_g
+
+ # layerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ weight3_grad,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class LlamaAttentionWithoutLinear(nn.Module):
+ """
+ Remove the Q/K/V/O projection layer in LlamaAttention module and only calculate the attention logic.
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
+ """
+
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2WithoutLinear(LlamaAttentionWithoutLinear):
+ """
+ Remove the Q/K/V/O projection layer in LlamaFlashAttention2 module and only calculate the attention logic.
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = query_states.size()
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaSdpaAttentionWithoutLinear(LlamaAttentionWithoutLinear):
+ """
+ Remove the Q/K/V/O projection layer in LlamaSdpaAttention module and only calculate the attention logic.
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ return attn_output, None, past_key_value
+
+
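+# Dispatch table mirroring Hugging Face Llama's attention backends; the decoder layer selects one of
+# these classes based on `config._attn_implementation`.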
+COAT_LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttentionWithoutLinear,
+ "flash_attention_2": LlamaFlashAttention2WithoutLinear,
+ "sdpa": LlamaSdpaAttentionWithoutLinear,
+}
+
+
+class CoatLlamaDecoderLayer(nn.Module):
+ def __init__(self, config: CoatLlamaConfig, layer_idx: int):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = COAT_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.BeforeAttention = CoatLlamaBeforeAttentionResidual(config, self.qargs, layer_idx)
+ self.AfterAttention = CoatLlamaAfterAttentionResidual(config, self.qargs, layer_idx)
+ self.MLPResidual = CoatLlamaMLPResidual(config, self.qargs, layer_idx, self.hidden_size)
+
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ quant_hidden_states: torch.Tensor,
+ scale_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): BF16 input to the layer of shape `(batch, seq_len, embed_dim)`
+ quant_hidden_states (`torch.float8_e4m3fn`): FP8 input to the layer of shape `(batch, seq_len, embed_dim)`
+ scale_hidden_states (`torch.bfloat16`): BF16 scaling factor to the layer of shape `(batch, seq_len, embed_dim // group_size)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model.
+ """
+
+ # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
+ residual, query_states, key_states, value_states = self.BeforeAttention(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
+ )
+
+ # Self Attention without any linear layer
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ # Coat: The Output Projection Linear Layer and Residual
+ hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states)
+
+ # Residual Connection, LayerNorm, and the whole MLP module
+ hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
+ )
+
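+        # The first element is always the (BF16, FP8, scale) triple consumed by the next decoder layer.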
+ outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class CoatLlamaPreTrainedModel(PreTrainedModel):
+ config_class = CoatLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class CoatLlamaModel(CoatLlamaPreTrainedModel):
+ """
+ Coat Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CoatLlamaDecoderLayer`]
+
+ Args:
+ config: CoatLlamaConfig
+ """
+
+ def __init__(self, config: CoatLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [CoatLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+        # COAT FP8 quantization entry/exit modules applied around the decoder layers (see forward)
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
+ self.quantize_output_after_block = Coat_quantize_end(self.qargs)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+        # Prepare the input for the COAT decoder layers
+ hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ quant_hidden_states,
+ scale_hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ quant_hidden_states,
+ scale_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # Summarize the output of the Decoder Layer
+ hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ _update_causal_mask = LlamaModel._update_causal_mask
+
+
+class CoatLlamaForCausalLM(CoatLlamaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = CoatLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = LlamaForCausalLM.forward
+
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+# TODO
+# class LlamaForSequenceClassification(LlamaPreTrainedModel):
+
+# class LlamaForQuestionAnswering(LlamaPreTrainedModel):
+
+# class LlamaForTokenClassification(LlamaPreTrainedModel):
+
+
+def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]):
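+    """
+    Remap the parameter names of a standard Hugging Face Llama state dict to the COAT module layout:
+    the Q/K/V projections move into `BeforeAttention`, the output projection into `AfterAttention`,
+    and the MLP projections into `MLPResidual`. All other keys are left unchanged.
+    """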
+ compatible_state_dict = {}
+
+ for key, value in state_dict.items():
+ if fnmatch(key, "*self_attn.q_proj*"):
+ new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj")
+ elif fnmatch(key, "*self_attn.k_proj*"):
+ new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj")
+ elif fnmatch(key, "*self_attn.v_proj*"):
+ new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj")
+ elif fnmatch(key, "*self_attn.o_proj*"):
+ new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj")
+
+ elif fnmatch(key, "*mlp.gate_proj*"):
+ new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj")
+ elif fnmatch(key, "*mlp.up_proj*"):
+ new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj")
+ elif fnmatch(key, "*mlp.down_proj*"):
+ new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj")
+
+ else:
+ new_key = key
+
+ compatible_state_dict[new_key] = value
+
+ return compatible_state_dict
+
+
+AutoConfig.register("fp8_llama", CoatLlamaConfig)
+AutoModel.register(CoatLlamaConfig, CoatLlamaModel)
+AutoModelForCausalLM.register(CoatLlamaConfig, CoatLlamaForCausalLM)
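+
+
+# Example usage (a sketch, not part of the conversion pipeline): once this module has been imported,
+# the auto classes registered above can load a checkpoint produced by `coat_llama_convert_from_hf.py`, e.g.
+#   from transformers import AutoModelForCausalLM
+#   model = AutoModelForCausalLM.from_pretrained("/path/to/converted/checkpoint")  # hypothetical path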
diff --git a/llava/model/coat/activation/models/coat_llama_convert_from_hf.py b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02da9e1b4211f8becaa2a54dc30a46fcf88185b
--- /dev/null
+++ b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import os
+from dataclasses import asdict, dataclass, field
+from typing import Optional
+
+import torch
+import transformers
+from coat.activation.models._fp8_quantization_config import QuantizationConfig
+from coat.activation.models.coat_llama import CoatLlamaConfig, CoatLlamaForCausalLM, make_state_dict_compatible
+from transformers import AutoConfig, AutoModelForCausalLM
+
+
+@dataclass
+class ConvertArguments:
+ model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"})
+ save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"})
+    cache_dir: Optional[str] = field(default=None, metadata={"help": "Directory to cache the model"})
+
+
+def download_and_convert_llama(convert_args: ConvertArguments, quantization_args: QuantizationConfig):
+ """
+ Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`,
+ and saves the converted model.
+
+    Args:
+        convert_args (ConvertArguments): The model name or path, the save path for the converted
+            weights, and an optional cache directory.
+        quantization_args (QuantizationConfig): The FP8 quantization settings stored in the
+            converted model's config.
+
+ Returns:
+ None
+ """
+ model_name = convert_args.model_name
+ save_path = convert_args.save_path
+ cache_dir = convert_args.cache_dir
+
+ # Step 1: Download the original LLaMA model
+ model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 2: Initialize the model configuration for FP8 or other custom config
+ config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 3: Apply make_state_dict_compatible to convert weights
+ compatible_state_dict = make_state_dict_compatible(model.state_dict())
+
+ # Step 4: Create a new model instance with compatible configuration
+ fp8_config = CoatLlamaConfig(**config.to_dict())
+ fp8_config.coat_fp8_args = asdict(quantization_args)
+
+ converted_model = AutoModelForCausalLM.from_config(fp8_config)
+ converted_model.load_state_dict(compatible_state_dict)
+
+ # Step 5: Save the converted model and configuration using save_pretrained
+ os.makedirs(save_path, exist_ok=True)
+ converted_model.save_pretrained(save_path)
+ print(f"Converted model saved at {save_path}")
+
+
+if __name__ == "__main__":
+ # Parse command-line arguments
+ parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8
+ convert_args, quantization_args = parser.parse_args_into_dataclasses()
+
+ # Call the function with parsed arguments
+ download_and_convert_llama(convert_args, quantization_args)
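+
+    # Example invocation (a sketch; `--model_name` and `--save_path` follow the ConvertArguments
+    # fields above, any QuantizationConfig flags are left at their defaults, and the model name and
+    # output path below are placeholders):
+    #   python coat_llama_convert_from_hf.py \
+    #       --model_name meta-llama/Llama-2-7b-hf \
+    #       --save_path ./coat-llama-2-7b-fp8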
diff --git a/llava/model/coat/activation/models/coat_olmo.py b/llava/model/coat/activation/models/coat_olmo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d2cdedd73c56a68396aebdab1a2762876abc53
--- /dev/null
+++ b/llava/model/coat/activation/models/coat_olmo.py
@@ -0,0 +1,1942 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Adapted from
+[MosaiclML](https://github.com/mosaicml/examples.git) and
+[minGPT](https://github.com/karpathy/minGPT.git)
+"""
+
+from __future__ import annotations
+
+import logging
+import math
+import sys
+from abc import abstractmethod
+from collections import defaultdict
+from functools import partial
+from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set, Tuple, cast
+
+import torch
+import torch.backends.cuda
+import torch.nn as nn
+import torch.nn.functional as F
+from olmo.aliases import PathOrStr
+from olmo.beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
+from olmo.config import (
+ ActivationCheckpointingStrategy,
+ ActivationType,
+ BlockType,
+ CheckpointType,
+ FSDPWrapStrategy,
+ InitFnType,
+ LayerNormType,
+ ModelConfig,
+ QuantActivationConfig,
+ ShardedCheckpointerType,
+ TrainConfig,
+)
+from olmo.exceptions import OLMoConfigurationError
+from olmo.initialization import init_normal
+from olmo.model import (
+ Activation,
+ BufferCache,
+ Dropout,
+ LayerNorm,
+ LayerNormBase,
+ OLMo,
+ OLMoBlock,
+ OLMoBlockGroup,
+ OLMoGenerateOutput,
+ OLMoOutput,
+ RMSLayerNorm,
+ RotaryEmbedding,
+ _non_meta_init_device,
+ activation_checkpoint_function,
+ alibi_attention_bias,
+ causal_attention_bias,
+ get_causal_attention_bias,
+ should_checkpoint_block,
+)
+from olmo.torch_util import ensure_finite_, get_cumulative_document_lengths
+from torch import einsum
+
+from ..real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+from ._fp8_weightcache import FP8CacheWeightModule
+from ._fp8manager import FP8Manager
+
+if sys.version_info.minor > 8:
+ from collections.abc import MutableMapping
+elif sys.version_info.minor == 8:
+ from typing import MutableMapping
+else:
+ raise SystemExit("This script supports Python 3.8 or higher")
+
+__all__ = [
+ "LayerNormBase",
+ "LayerNorm",
+ "RMSLayerNorm",
+ "RotaryEmbedding",
+ "Activation",
+ "GELU",
+ "ReLU",
+ "SwiGLU",
+ "OLMoBlock",
+ "OLMoSequentialBlock",
+ "OLMo",
+ "OLMoOutput",
+ "OLMoGenerateOutput",
+]
+
+
+log = logging.getLogger(__name__)
+
+
+class CoatOLMoBeforeAttentionResidual(FP8CacheWeightModule):
+ """
+    This is a typical transformer attention module that contains (1) the residual connection, (2) LayerNorm / RMSNorm,
+    and (3) one linear layer (the fused QKV projection).
+ """
+
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, fused_dims: tuple):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
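+        # Per-pass precision/format settings: the leading letter is the pass (f = forward, b = backward)
+        # and the trailing letters are the tensor type (a = activation, w = weight, o = output),
+        # e.g. "fabit" is the forward activation format.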
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+ self.ln_normalized_shape = config.d_model
+ self.att_proj = nn.Linear(config.d_model, sum(fused_dims), bias=config.include_bias, device=config.init_device)
+
+ self.attn_norm = LayerNorm.build(config)
+
+ def forward(self, re_x, x, s):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # Prepare
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch)
+ return _CoatOLMoBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.att_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch
+ )
+ return _CoatOLMoBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.att_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
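+            # Inference path: fall back to the standard (unquantized) LayerNorm + fused QKV projection.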
+ return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _CoatOLMoBeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward(
+ in_x, in_s, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None
+ ctx.weight = weight1_origin, weight1_s
+ else:
+ ctx.weight = weight1_t, weight1_s
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ return re_x, fc1_x
+
+ @staticmethod
+ def backward(ctx, fp_grad, flash_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s = ctx.weight
+ group_size = ctx.group_size
+ mean, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout.
+ flash_g, flash_gs, flash_g_t = fp8_quantize_pertensor_transpose(
+ flash_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ fc1_g, att_proj_wg = fp8_linear_backward(
+ ln_x_t, ln_s, flash_g, flash_gs, flash_g_t, weight1_t, weight1_s, group_size
+ )
+
+ # LayerNorm
+ in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps)
+
+ # Add the gradient together, and prepare the input of the next layer.
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+        # For autograd, the forward's data type should be the same as the backward tensor's. This does not change the actual binary representation.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+        # Although the next operator is a linear layer in the MLPResidual module, we return in_sg_g16 so its size stays compatible with the corresponding forward input; otherwise autograd will reject it.
+ return re_g, in_g, in_sg_g16, att_proj_wg, None, None, None, None, None, None, None, None, None
+
+
+class CoatOLMoAfterAttentionResidual(FP8CacheWeightModule):
+ """
+    This is a typical transformer attention module that contains (1) the residual connection and (2) one linear layer
+    (the attention output projection).
+ """
+
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+ self.attn_out = nn.Linear(config.d_model, config.d_model, bias=config.include_bias, device=config.init_device)
+
+ def forward(self, re_x, in_x):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ weight2_s = self.prepare_weight(self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch)
+
+ return _CoatOLMoAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.attn_out.weight,
+ None,
+ None,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatOLMoAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.attn_out.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ return re_x + self.attn_out(in_x), None, None
+
+
+class _CoatOLMoAfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, re_x, flash_x, weight2_origin, weight2, weight2_t, weight2_s, group_size, fwobits, layer_id, config, qargs
+ ):
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+        # Attention Projection Linear Layer
+ if qargs.weight_memory_efficient:
+ assert weight2 is None # memory efficient
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+        fc2_x = fp8_linear_forward(flash_qx, flash_s, weight2, weight2_s, False, group_size)
+
+        # Add the activations together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, flash_qx.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight2_t is None
+ ctx.weight = weight2_origin, weight2_s
+ else:
+ ctx.weight = weight2_t, weight2_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight2_t, weight2_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+        # We do not save an extra copy of flash_x, to reduce memory usage.
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ fc2_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size
+ )
+
+ return fp_grad, fc2_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class CoatOLMoMLPResidual(FP8CacheWeightModule):
+ """
+    This is a typical transformer MLP module that contains (1) the residual connection, (2) LayerNorm / RMSNorm,
+    (3) two or three linear layers, and (4) a GELU / SiLU activation.
+ """
+
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+ self.ln_normalized_shape = config.d_model
+ self.act_output_multiplier = 0.5 if config.activation_type == ActivationType.swiglu else 1
+ self.ff_proj = nn.Linear(config.d_model, hidden_size, bias=config.include_bias, device=config.init_device)
+ self.ff_out = nn.Linear(
+ int(self.act_output_multiplier * hidden_size),
+ config.d_model,
+ bias=config.include_bias,
+ device=config.init_device,
+ )
+ self.training = True
+
+ # below is only used when training = False
+ self.ff_norm = LayerNorm.build(config)
+ self.act = Activation.build(config)
+ assert (self.act.output_multiplier * hidden_size) % 1 == 0
+
+ def forward(self, re_x, x, s):
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch)
+
+ return _CoatOLMoMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.ff_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.ff_out.weight,
+ None,
+ None,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatOLMoMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.ff_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.ff_out.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ og_x = re_x
+ re_x = self.ff_norm(re_x)
+ re_x = self.ff_proj(re_x)
+ re_x = self.act(re_x)
+ re_x = self.ff_out(re_x)
+ re_x = og_x + re_x
+ return re_x, None, None
+
+
+class _CoatOLMoMLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward(
+ in_x, in_s, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer.
+ if qargs.weight_memory_efficient:
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ fc1_x, fc1_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size)
+
+        # NOTE: Be careful of the order: the fused output is split into (up, gate).
+ up_x, gate_x = fc1_x.chunk(2, dim=-1)
+ up_s, gate_s = fc1_s.chunk(2, dim=-1)
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight2 is None: # memory efficient
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ fc2_x = fp8_linear_forward(mul_x, mul_s, weight2, weight2_s, False, group_size)
+
+ # Add the activation together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, mul_x.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
+
+        if qargs.weight_memory_efficient:
+            # weight1/2_origin will not be saved twice, so this is more memory efficient.
+            assert weight1_t is None
+            ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s)
+        else:
+            # weight1/2_t differ from the original weights, so saving them consumes an additional memory footprint.
+            ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s) = ctx.weight
+ group_size = ctx.group_size
+ mean, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ if qargs.weight_memory_efficient:
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size
+ )
+
+ # [MEM TEST]
+ del out_g, out_g_t, weight2_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2) = fp8_mul_backward(silu_x, silu_s, up_x, up_s, fc2_g, group_size, fwobits["babit"])
+
+ # Silu activation
+ silu_g, silu_gs = fp8_silu_backward(gate_x, gate_s, mul_g1, group_size, fwobits["babit"])
+
+        # Prepare the input of the linear layer. NOTE: Be careful of the order; it must match the (up, gate) split used in the forward pass.
+ gateup_g = torch.cat([mul_g2, silu_g], dim=-1)
+ gateup_gs = torch.cat([mul_gs2, silu_gs])
+ gateup_gs = torch.max(gateup_gs)
+
+ gateup_g, gateup_gs, gateup_g_t = fp8_division_transpose(
+ gateup_g, group_size, fwobits["babit"], gateup_gs, stochastic=False
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, gateup_g, gateup_gs, gateup_g_t, weight1_t, weight1_s, group_size
+ )
+
+ # layerNorm
+ in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps)
+
+ # Add the gradient together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class CoatOLMoBlock(nn.Module):
+ """
+ A base class for transformer block implementations.
+ """
+
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
+ super().__init__()
+ self.layer_id = layer_id
+ self.config = config
+ self.qargs = qargs
+ self.hidden_size = (
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
+ )
+ self.__cache = cache
+ assert config.d_model % config.n_heads == 0
+
+ self._activation_checkpoint_fn: Callable | None = None
+
+ # Dropout.
+ self.dropout = Dropout(config.residual_dropout)
+
+ # Layer norms.
+ self.k_norm: LayerNormBase | None = None
+ self.q_norm: LayerNormBase | None = None
+ if config.attention_layer_norm:
+ assert config.effective_n_kv_heads is not None
+ self.k_norm = LayerNormBase.build(
+ config,
+ size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
+ elementwise_affine=config.attention_layer_norm_with_affine,
+ )
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
+
+ # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
+ if config.clip_qkv is not None:
+ assert config.clip_qkv > 0
+
+ # Activation function.
+ self.act = Activation.build(config)
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+ if not self.qargs.use_quantize_model:
+ # Attention output projection.
+ self.attn_out = nn.Linear(
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
+ )
+
+ # Feed-forward output projection.
+ self.ff_out = nn.Linear(
+ int(self.act.output_multiplier * self.hidden_size),
+ config.d_model,
+ bias=config.include_bias,
+ device=config.init_device,
+ )
+ self.ff_out._is_residual = True # type: ignore
+
+ # Rotary embeddings.
+ if self.config.rope:
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
+
+ self.flash_attn_func = None
+ self.flash_attn_varlen_func = None
+ if config.flash_attention:
+ try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
+
+ self.flash_attn_func = flash_attn_func
+ self.flash_attn_varlen_func = flash_attn_varlen_func
+ except ModuleNotFoundError:
+ pass
+
+ def reset_parameters(self):
+ if self.k_norm is not None:
+ self.k_norm.reset_parameters()
+ if self.q_norm is not None:
+ self.q_norm.reset_parameters()
+
+ if not self.qargs.use_quantize_model:
+ if self.config.init_fn == InitFnType.normal:
+ attn_out_std = ff_out_std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+
+ elif self.config.init_fn == InitFnType.mitchell:
+ attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
+ ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1)))
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ elif self.config.init_fn == InitFnType.full_megatron:
+ attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
+ init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
+
+ def set_activation_checkpointing(
+ self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None
+ ):
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
+ self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
+ else:
+ self._activation_checkpoint_fn = None
+
+ @classmethod
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
+ target_dtype = input_dtype
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
+ # `is_autocast_cpu_enabled()` for CPU autocast.
+ # See https://github.com/pytorch/pytorch/issues/110966.
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+ target_dtype = torch.get_autocast_cpu_dtype()
+ if bias.dtype != target_dtype:
+ bias = bias.to(target_dtype)
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
+ return bias
+
+ def _scaled_dot_product_attention(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attn_mask: torch.Tensor | None = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Computes scaled dot product attention on query, key and value tensors, using an optional
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
+ """
+ if max_doc_len is not None and cu_doc_lens is not None:
+ assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking"
+ assert attn_mask is None, "attn-mask is currently not supported with document masking"
+ B, T, D = q.size(0), q.size(2), q.size(3)
+ r = self.flash_attn_varlen_func(
+ q.transpose(1, 2).view(B * T, -1, D),
+ k.transpose(1, 2).view(B * T, -1, D),
+ v.transpose(1, 2).view(B * T, -1, D),
+ cu_doc_lens,
+ cu_doc_lens,
+ max_doc_len,
+ max_doc_len,
+ dropout_p=dropout_p,
+ causal=is_causal,
+ )
+ return r.view(B, T, -1, D).transpose(1, 2)
+ elif self.flash_attn_func is not None and attn_mask is None:
+ r = self.flash_attn_func(
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
+ )
+ return r.transpose(1, 2)
+ else:
+ # torch's sdpa doesn't support GQA, so we're doing this
+ assert k.size(1) == v.size(1)
+ num_kv_heads = k.size(1)
+ num_q_heads = q.size(1)
+ if num_q_heads != num_kv_heads:
+ assert num_q_heads % num_kv_heads == 0
+ k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+ v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+
+ return F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ )
+
+ def attention(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attention_bias: torch.Tensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ B, T, C = q.size() # batch size, sequence length, d_model
+ dtype = k.dtype
+
+ # Optionally apply layer norm to keys and queries.
+ if self.q_norm is not None and self.k_norm is not None:
+ q = self.q_norm(q).to(dtype=dtype)
+ k = self.k_norm(k).to(dtype=dtype)
+
+ # Move head forward to be next to the batch dim.
+ # shape: (B, nh, T, hs)
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
+ # shape: (B, n_kv_h, T, hs)
+ k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+ # shape: (B, n_kv_h, T, hs)
+ v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ present = (k, v) if use_cache else None
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
+
+ if self.config.rope:
+ # Apply rotary embeddings.
+ q, k = self.rotary_emb(q, k)
+
+ if attention_bias is not None:
+ # Resize and cast attention bias.
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
+ # cause the SDP attn function to produce NaNs.
+ attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype)
+
+ # Get the attention scores.
+ # shape: (B, nh, T, hs)
+ att = self._scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=attention_bias,
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
+ is_causal=attention_bias is None,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+ # Re-assemble all head outputs side-by-side.
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
+
+        # NOTE: The output projection (attn_out) is applied outside of this attention function.
+ return att, present
+
+ @abstractmethod
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_bias: torch.FloatTensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ raise NotImplementedError
+
+ @classmethod
+ def build(cls, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache) -> OLMoBlock:
+ if config.block_type == BlockType.sequential:
+ return CoatOLMoSequentialBlock(layer_id, config, qargs, cache)
+ elif config.block_type == BlockType.llama:
+ return CoatOLMoLlamaBlock(layer_id, config, qargs, cache)
+ else:
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
+
+
+class CoatOLMoSequentialBlock(CoatOLMoBlock):
+ """
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+ (plus another skip connection). To compute it as ``LN(MLP(x + LN(Attention(x))))``,
+ use the flag `norm_after`.
+ """
+
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
+ super().__init__(layer_id, config, qargs, cache)
+ # Attention input projection. Projects x -> (q, k, v)
+
+ assert not self.config.norm_after, "COAT currently does not support PostNorm"
+
+ head_dim = config.d_model // config.n_heads
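+        # Output dimensions of the fused attention input projection: (q, k, v). `att_proj` produces
+        # sum(fused_dims) features, which `forward` splits back into q/k/v.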
+ self.fused_dims = (
+ config.d_model,
+ config.effective_n_kv_heads * head_dim,
+ config.effective_n_kv_heads * head_dim,
+ )
+
+ if self.qargs.use_quantize_model:
+ self.BeforeAttention = CoatOLMoBeforeAttentionResidual(config, qargs, self.layer_id, self.fused_dims)
+ self.AfterAttention = CoatOLMoAfterAttentionResidual(config, qargs, self.layer_id)
+ self.MLPResidual = CoatOLMoMLPResidual(config, qargs, self.layer_id, self.hidden_size)
+ else:
+ self.att_proj = nn.Linear(
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
+ )
+ # Feed-forward input projection.
+ self.ff_proj = nn.Linear(
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+ )
+
+ # Layer norms.
+ self.attn_norm = LayerNorm.build(config, size=config.d_model)
+ self.ff_norm = LayerNorm.build(config, size=config.d_model)
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ self.attn_norm.reset_parameters()
+ self.ff_norm.reset_parameters()
+ # NOTE: the standard deviation for these weights does not depend on the layer.
+
+ if self.qargs.use_quantize_model: # The initialization appears here, not in CoatOLMoBlock's reset_parameters
+ if self.config.init_fn == InitFnType.normal:
+ attn_out_std = ff_out_std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+
+ elif self.config.init_fn == InitFnType.mitchell:
+ attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
+ ff_out_std = 1 / (math.sqrt(2 * self.MLPResidual.ff_out.in_features * (self.layer_id + 1)))
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ elif self.config.init_fn == InitFnType.full_megatron:
+ attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ init_normal(self.AfterAttention.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
+ init_normal(self.MLPResidual.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
+
+ if self.config.init_fn == InitFnType.normal:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+ elif self.config.init_fn == InitFnType.mitchell:
+ std = 1 / math.sqrt(self.config.d_model)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ elif self.config.init_fn == InitFnType.full_megatron:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ if not self.qargs.use_quantize_model:
+ init_normal(self.att_proj, std, cutoff_factor)
+ init_normal(self.ff_proj, std, cutoff_factor)
+ else:
+ init_normal(self.BeforeAttention.att_proj, std, cutoff_factor)
+ init_normal(self.MLPResidual.ff_proj, std, cutoff_factor)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ qx: torch.Tensor,
+ sx: torch.Tensor,
+ attention_bias: torch.Tensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ # Get query, key, value projections.
+ # shape:
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
+ # k, v: (batch_size, seq_len, d_model // n_heads)
+ # - for group query attn q: (batch_size, seq_len, d_model)
+ # k, v: (batch_size, seq_len, d_model // n_kv_heads)
+
+        if self.qargs.use_quantize_model:
+ x, qkv = self.BeforeAttention(x, qx, sx)
+ else:
+ # apply norm before
+ h = self.attn_norm(x)
+
+ qkv = self.BeforeAttention.att_proj(h)
+
+ if self.config.clip_qkv is not None:
+ qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ q, k, v = qkv.split(self.fused_dims, dim=-1)
+
+ # Get attention scores.
+ att, cache = self.attention(
+ q,
+ k,
+ v,
+ attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+        if self.qargs.use_quantize_model:
+ x, qx, sx = self.AfterAttention(x, att)
+ else:
+ att = self.AfterAttention.attn_out(att)
+
+ # Add attention scores.
+ # shape: (B, T, C)
+ x = x + self.dropout(att)
+
+ if self.qargs.use_quantize_model:
+ x, qx, sx = self.MLPResidual(x, qx, sx)
+ else:
+ # Add feed-forward projection.
+ # shape: (batch_size, seq_len, d_model)
+ og_x = x
+
+ x = self.ff_norm(x)
+
+ x = self.MLPResidual.ff_proj(x)
+
+ if self._activation_checkpoint_fn is not None:
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
+ else:
+ x = self.act(x)
+ x = self.MLPResidual.ff_out(x)
+
+ x = self.dropout(x)
+ x = og_x + x
+
+ return x, qx, sx, cache
+
+
+class CoatOLMoLlamaBlock(CoatOLMoBlock):
+ """
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+ (plus another skip connection). This block is similar to `OLMoSequentialBlock`
+ but some operations have slightly different implementations to imitate the
+ behavior of Llama.
+ """
+
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
+ super().__init__(layer_id, config, qargs, cache)
+ # Layer norms.
+ self.attn_norm = LayerNorm.build(config)
+ self.ff_norm = LayerNorm.build(config)
+ self.__cache = cache
+
+ # Attention input projection. Projects x -> (q, k, v)
+ if config.multi_query_attention:
+ q_proj_out_dim = config.d_model
+ k_proj_out_dim = config.d_model // config.n_heads
+ v_proj_out_dim = config.d_model // config.n_heads
+ else:
+ q_proj_out_dim = config.d_model
+ k_proj_out_dim = config.d_model
+ v_proj_out_dim = config.d_model
+ self.q_proj = nn.Linear(config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device)
+ self.k_proj = nn.Linear(config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device)
+ self.v_proj = nn.Linear(config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device)
+
+ # Feed-forward input projection.
+ self.ff_proj = nn.Linear(config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device)
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ self.attn_norm.reset_parameters()
+ self.ff_norm.reset_parameters()
+ # NOTE: the standard deviation for these weights does not depend on the layer.
+
+ if self.config.init_fn == InitFnType.normal:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+ elif self.config.init_fn == InitFnType.mitchell:
+ std = 1 / math.sqrt(self.config.d_model)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ elif self.config.init_fn == InitFnType.full_megatron:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ init_normal(self.q_proj, std, cutoff_factor)
+ init_normal(self.k_proj, std, cutoff_factor)
+ init_normal(self.v_proj, std, cutoff_factor)
+ init_normal(self.ff_proj, std, cutoff_factor)
+
+ def _scaled_dot_product_attention(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attn_mask: torch.Tensor | None = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if max_doc_len is not None or cu_doc_lens is not None:
+ raise NotImplementedError(f"attention document masking is not implemented for {self.__class__.__name__}")
+
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
+
+ if is_causal:
+ assert attn_mask is None
+
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
+ attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
+ elif attn_mask is not None:
+ attn_bias = attn_mask.to(q.dtype)
+ else:
+ attn_bias = torch.zeros_like(attn_weights)
+
+ attn_weights += attn_bias
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
+ return torch.matmul(attn_weights, v)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ qx: torch.Tensor,
+ sx: torch.Tensor,
+ attention_bias: torch.Tensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ # Get query, key, value projections.
+ # shape:
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
+ # k, v: (batch_size, seq_len, d_model // n_heads)
+ x_normed = self.attn_norm(x)
+ q = self.q_proj(x_normed)
+ k = self.k_proj(x_normed)
+ v = self.v_proj(x_normed)
+
+ if self.config.clip_qkv is not None:
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ # Get attention scores.
+ att, cache = self.attention(
+ q,
+ k,
+ v,
+ attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+ att = self.attn_out(att) # NOTE: we move the attn_out outside the self.attention module
+
+        # Add the attention output back to the residual stream.
+ # shape: (B, T, C)
+ x = x + self.dropout(att)
+
+ # Add feed-forward projection.
+ # shape: (batch_size, seq_len, d_model)
+ og_x = x
+ if self._activation_checkpoint_fn is not None:
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
+ else:
+ x = self.ff_norm(x)
+ x = self.ff_proj(x)
+ if self._activation_checkpoint_fn is not None:
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
+ else:
+ x = self.act(x)
+ x = self.ff_out(x)
+ x = self.dropout(x)
+ x = og_x + x
+
+ return x, cache
+
+
+class CoatOLMoBlockGroup(nn.ModuleList):
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Iterable[nn.Module] | None = None):
+ super().__init__(modules)
+ self.config = config
+ self.layer_offset = layer_offset
+ self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_bias: torch.FloatTensor | None = None,
+ layers_past: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]] | None]:
+ attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None
+ for block_idx, block in enumerate(self):
+ layer_past = None if layers_past is None else layers_past[block_idx]
+ block_idx += self.layer_offset
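+            # block_idx is now the global layer index, which the activation-checkpointing policy expects.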
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
+ # shape: (batch_size, seq_len, d_model)
+ x, cache = self._activation_checkpoint_fn( # type: ignore
+ block,
+ x,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ else:
+ # shape: (batch_size, seq_len, d_model)
+ x, cache = block(
+ x,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ if attn_key_values is not None:
+ assert cache is not None
+ attn_key_values.append(cache)
+ return x, attn_key_values
+
+ def reset_parameters(self):
+ for block in self:
+ block.reset_parameters()
+
+ def set_activation_checkpointing(
+ self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None
+ ):
+ self.activation_checkpointing_strategy = strategy
+ for block in self:
+ block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
+
+
+class CoatOLMo(nn.Module):
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, init_params: bool = True):
+ super().__init__()
+ self.config = config
+ self.qargs = qargs
+ self.__cache = BufferCache()
+
+ # Validate config.
+ if self.config.alibi and self.config.flash_attention:
+ raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
+
+ if self.config.alibi and self.config.rope:
+ raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
+
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
+ if self.config.embedding_size < self.config.vocab_size:
+ raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
+ elif self.config.embedding_size % 128 != 0:
+ import warnings
+
+ warnings.warn(
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
+ )
+
+ self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
+
+ if not (
+ 0 < self.config.block_group_size <= self.config.n_layers
+ and self.config.n_layers % self.config.block_group_size == 0
+ ):
+ raise OLMoConfigurationError("n layers must be divisible by block group size")
+
+ torch.backends.cuda.enable_flash_sdp(True)
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
+
+ self.transformer = nn.ModuleDict(
+ dict(
+ wte=nn.Embedding(config.embedding_size or config.vocab_size, config.d_model, device=config.init_device),
+ emb_drop=Dropout(config.embedding_dropout),
+ ln_f=LayerNorm.build(config),
+ )
+ )
+
+ blocks = [CoatOLMoBlock.build(i, config, qargs, self.__cache) for i in range(config.n_layers)]
+ if self.config.block_group_size > 1:
+ block_groups = [
+ CoatOLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
+ for i in range(0, config.n_layers, config.block_group_size)
+ ]
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
+ else:
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
+
+ if not (self.config.alibi or self.config.rope):
+ self.transformer.update(
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
+ )
+ if not config.weight_tying:
+ self.transformer.update(
+ {
+ "ff_out": nn.Linear(
+ config.d_model,
+ config.embedding_size or config.vocab_size,
+ bias=config.include_bias,
+ device=config.init_device,
+ )
+ }
+ )
+ if config.embedding_layer_norm:
+ self.transformer.update({"emb_norm": LayerNorm.build(config)})
+
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
+ if init_params and self.config.init_device != "meta":
+ self.reset_parameters()
+ self.__num_fwd_flops: int | None = None
+ self.__num_bck_flops: int | None = None
+
+ # Warm up cache.
+ if self.config.alibi:
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
+
+ # Quantize
+ self.quantize_input_before_block = Coat_quantize_bgn(qargs)
+ self.quantize_output_after_block = Coat_quantize_end(qargs)
+
+ set_activation_checkpointing = OLMo.set_activation_checkpointing
+ device = OLMo.device
+ reset_parameters = OLMo.reset_parameters
+ get_alibi_attention_bias = OLMo.get_alibi_attention_bias
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ input_embeddings: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ attention_bias: torch.Tensor | None = None,
+ past_key_values: Sequence[tuple[torch.Tensor, torch.Tensor]] | None = None,
+ use_cache: bool = False,
+ last_logits_only: bool = False,
+ output_hidden_states: bool | None = None,
+ doc_lens: torch.Tensor | None = None,
+ max_doc_lens: Sequence[int] | None = None,
+ ) -> OLMoOutput:
+ """
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
+ embeddings. When provided, it is treated as the output of the input embedding layer.
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
+ which input IDs are masked. A `1` value in the mask means that
+ the corresponding input ID should *not* be ignored. A `0` means
+ that the corresponding input ID is masked.
+
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
+ library.
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
+ to introduce causal or other biases.
+
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
+ element in the sequence.
+
+ If the tensor is a float tensor, it will just be added to the attention
+ scores before the softmax.
+
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
+ :param past_key_values: Pre-computed keys and values for each attention block.
+ Can be used to speed up sequential decoding. The `input_ids` which have
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
+ :param use_cache: If `True`, return key and value tensors for each block.
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
+ This can speed up decoding when you only care about the next token.
+ :param doc_lens: Document lengths to use in attention for intra-document masking.
+ Shape `(batch_size, max_docs)`.
+ :param max_doc_lens: Maximum document length for each instance in the batch.
+ """
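+        # Hedged usage sketch (assumes a built `CoatOLMo` instance `model` and a tokenized batch on the same device):
+        #   out = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        #   next_token_logits = out.logits[:, -1, :]
+        #   past = out.attn_key_values  # pass back as `past_key_values` for incremental decoding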
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+ if past_key_values:
+ assert len(past_key_values) == self.config.n_layers
+
+ batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+ if past_key_values is None:
+ past_length = 0
+ else:
+ past_length = past_key_values[0][0].size(-2)
+
+ max_doc_len: int | None = None
+ cu_doc_lens: torch.Tensor | None = None
+ if doc_lens is not None and max_doc_lens is not None:
+ max_doc_len = max(max_doc_lens)
+ cu_doc_lens = get_cumulative_document_lengths(doc_lens)
+
+ # Get embeddings of input.
+ # shape: (batch_size, seq_len, d_model)
+ x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
+
+ # Apply embedding layer norm.
+ if self.config.embedding_layer_norm:
+ x = self.transformer.emb_norm(x)
+
+ if not (self.config.alibi or self.config.rope):
+ # Get positional embeddings.
+ # shape: (1, seq_len)
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+ # shape: (1, seq_len, d_model)
+ pos_emb = self.transformer.wpe(pos) # type: ignore
+ x = pos_emb + x
+
+ # Apply dropout.
+ # shape: (batch_size, seq_len, d_model)
+ x = self.transformer.emb_drop(x) # type: ignore
+
+ # Transform the attention mask into what the blocks expect.
+ if attention_mask is not None:
+ # shape: (batch_size, 1, 1, seq_len)
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+
+ # Merge attention mask with attention bias.
+ if (
+ attention_bias is not None
+ or attention_mask is not None
+ or self.config.alibi
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
+ # scores correctly.
+ or past_key_values is not None
+ ):
+ if attention_bias is None and self.config.alibi:
+ attention_bias = get_causal_attention_bias(
+ self.__cache, past_length + seq_len, x.device
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+ elif attention_bias is None:
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+ elif attention_bias.dtype in (torch.int8, torch.bool):
+ attention_bias = attention_bias.to(dtype=torch.float)
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+ # Transform to the right shape and data type.
+ mask_len = seq_len
+ if attention_mask is not None:
+ mask_len = attention_mask.shape[-1]
+ elif past_key_values is not None:
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+ # Add in the masking bias.
+ if attention_mask is not None:
+ attention_bias = attention_bias + attention_mask
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
+ # it can produce NaNs.
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+ attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None
+
+ # decoder layers
+ all_hidden_states = []
+
+        # Prepare the input for the COAT decoder layers.
+ x, qx, sx = self.quantize_input_before_block(x)
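+        # qx holds the FP8-quantized hidden state and sx its quantization scales; the (x, qx, sx)
+        # triple is threaded through every decoder block (see CoatOLMoBlock.forward).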
+
+ # Apply blocks one-by-one.
+ if self.config.block_group_size == 1:
+ for block_idx, block in enumerate(self.transformer.blocks):
+ if output_hidden_states:
+ # add hidden states
+ all_hidden_states.append(x)
+
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
+ # shape: (batch_size, seq_len, d_model)
+ x, qx, sx, cache = self._activation_checkpoint_fn(
+ block,
+ x,
+ qx,
+ sx,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ else:
+ # shape: (batch_size, seq_len, d_model)
+ x, qx, sx, cache = block(
+ x,
+ qx,
+ sx,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+ if attn_key_values is not None:
+ assert cache is not None
+ attn_key_values.append(cache)
+ else:
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
+ if output_hidden_states:
+ # add hidden states
+ all_hidden_states.append(x)
+
+ layers_past = (
+ None
+ if past_key_values is None
+ else past_key_values[
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
+ ]
+ )
+ x, cache = block_group(
+ x,
+ attention_bias=attention_bias,
+ layers_past=layers_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ if attn_key_values is not None:
+ assert cache is not None
+ attn_key_values.extend(cache)
+
+        # Collapse the (x, qx, sx) triple back into a single hidden state after the last decoder layer.
+ x = self.quantize_output_after_block(x, qx, sx)
+
+ if last_logits_only:
+ # shape: (batch_size, 1, d_model)
+ x = x[:, -1, :].unsqueeze(1)
+
+ # Apply final layer norm.
+ # shape: (batch_size, seq_len or 1, d_model)
+ x = self.transformer.ln_f(x) # type: ignore
+ if output_hidden_states:
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
+ all_hidden_states.append(x)
+
+ # Get logits.
+ # shape: (batch_size, seq_len or 1, vocab_size)
+ if self.config.weight_tying:
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
+ else:
+ logits = self.transformer.ff_out(x) # type: ignore
+ if self.config.scale_logits:
+ logits.mul_(1 / math.sqrt(self.config.d_model))
+
+ return OLMoOutput(
+ logits=logits,
+ attn_key_values=attn_key_values,
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
+ )
+
+ def get_fsdp_wrap_policy(self, wrap_strategy: FSDPWrapStrategy | None = None):
+ if wrap_strategy is None:
+ return None
+
+ # The 'recurse' mode for the wrap function does not behave like you'd expect.
+ # Even if we return False, it may still recurse because PyTorch does what it wants,
+ # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
+ # but not other linear layers within a block.
+ # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
+ # return True in 'recurse' mode for simplicity.
+ size_based_module_to_wrap = {self.transformer.wte}
+ if hasattr(self.transformer, "ff_out"):
+ size_based_module_to_wrap.add(self.transformer.ff_out)
+
+ if wrap_strategy == FSDPWrapStrategy.by_block:
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, CoatOLMoBlock)
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, (CoatOLMoBlock,)) or module in size_based_module_to_wrap
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group:
+ if self.config.block_group_size <= 1:
+ raise OLMoConfigurationError(
+ "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
+ )
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, CoatOLMoBlockGroup)
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
+ if self.config.block_group_size <= 1:
+ raise OLMoConfigurationError(
+ "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
+ )
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, (CoatOLMoBlockGroup,)) or module in size_based_module_to_wrap
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.size_based:
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+
+ return size_based_auto_wrap_policy
+ elif wrap_strategy in {
+ FSDPWrapStrategy.one_in_two,
+ FSDPWrapStrategy.one_in_three,
+ FSDPWrapStrategy.one_in_four,
+ FSDPWrapStrategy.one_in_five,
+ }:
+ c = {
+ FSDPWrapStrategy.one_in_two: 2,
+ FSDPWrapStrategy.one_in_three: 3,
+ FSDPWrapStrategy.one_in_four: 4,
+ FSDPWrapStrategy.one_in_five: 5,
+ }[wrap_strategy]
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, CoatOLMoBlock) and module.layer_id % c == 0
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ else:
+ raise NotImplementedError(wrap_strategy)
+
+ num_params = OLMo.num_params
+
+ @property
+ def num_fwd_flops(self):
+ if self.__num_fwd_flops:
+ return self.__num_fwd_flops
+
+ # embedding table is just a lookup in the forward pass
+ n_params = self.num_params(include_embedding=False)
+ # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
+ # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
+ # this gets us FLOPs / token
+ params_flops_per_token = 2 * n_params
+        # 2 FLOPs per MAC, and attention performs two matmuls per layer (A = Q @ K^T and out = A @ V), hence the two factors of 2
+ attn_flops_per_token = self.config.n_layers * 2 * 2 * (self.config.d_model * self.config.max_sequence_length)
+ self.__num_fwd_flops = params_flops_per_token + attn_flops_per_token
+ return self.__num_fwd_flops
+
+ @property
+ def num_bck_flops(self):
+ if self.__num_bck_flops:
+ return self.__num_bck_flops
+
+ n_params = self.num_params()
+ params_flops_per_token = 4 * n_params
+ attn_flops_per_token = self.config.n_layers * 8 * (self.config.d_model * self.config.max_sequence_length)
+ self.__num_bck_flops = params_flops_per_token + attn_flops_per_token
+ return self.__num_bck_flops
+
+ generate = OLMo.generate
+
+ @classmethod
+ def from_checkpoint(
+ cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: CheckpointType | None = None
+ ) -> CoatOLMo:
+ """
+ Load an OLMo model from a checkpoint.
+ """
+ from olmo.util import resource_path
+
+ # Guess checkpoint type.
+ if checkpoint_type is None:
+ try:
+ if resource_path(checkpoint_dir, "model.pt").is_file():
+ checkpoint_type = CheckpointType.unsharded
+ else:
+ checkpoint_type = CheckpointType.sharded
+ except FileNotFoundError:
+ checkpoint_type = CheckpointType.sharded
+
+ # Load config.
+ config_path = resource_path(checkpoint_dir, "config.yaml")
+ model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
+
+ if checkpoint_type == CheckpointType.unsharded:
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
+ model_config.init_device = "cpu"
+ model = CoatOLMo(model_config)
+
+ # Load state dict directly to target device.
+ state_dict_path = resource_path(checkpoint_dir, "model.pt")
+ state_dict = torch.load(state_dict_path, map_location="cpu")
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
+ model = model.to(torch.device(device))
+ else:
+ train_config = TrainConfig.load(config_path)
+ if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core:
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state # type: ignore
+
+ model_config.init_device = device
+ model = CoatOLMo(model_config)
+ load_model_and_optim_state(checkpoint_dir, model)
+ else:
+ # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
+ from olmo.checkpoint import load_model_state
+
+ # Initialize model on target device. In this case the state dict is loaded in-place
+ # so it's not necessary to start on CPU if the target device is a GPU.
+ model_config.init_device = device
+ model = CoatOLMo(model_config)
+
+ # Load state dict in place.
+ load_model_state(checkpoint_dir, model)
+
+ return model.eval()
+
+ def _make_state_dict_compatible(
+ self, state_dict: dict[str, torch.Tensor]
+ ) -> tuple[dict[str, torch.Tensor], dict[str, set[str]]]:
+ """
+ Handles some cases where the state dict is valid yet may need to be transformed in order to
+ be loaded.
+
+ This modifies the state dict in-place and also returns it, along with a mapping of original key
+ names to new key names in cases where the keys were simply renamed. That mapping can be used
+ to make a corresponding optimizer state dict compatible as well.
+ """
+ import re
+ from fnmatch import fnmatch
+
+ new_keys_to_og_keys: dict[str, str] = {}
+
+ # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
+ # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
+ # fine without the prefixes. This also simplifies the other steps below.
+ for key in list(state_dict.keys()):
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = key
+
+ # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
+ if self.config.block_type == BlockType.sequential:
+ for key in list(state_dict.keys()):
+ if fnmatch(key, "transformer.*.norm.weight"):
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.*.norm.bias"):
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+
+        # Real quantization moves these linear layers into the BeforeAttention / AfterAttention / MLPResidual wrappers, so rename the corresponding keys.
+ if self.qargs.use_quantize_model == "coat_real":
+ for key in list(state_dict.keys()):
+ if fnmatch(key, "transformer.blocks.*.att_proj.weight") and "BeforeAttention" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("att_proj.weight", "BeforeAttention.att_proj.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.blocks.*.attn_out.weight") and "AfterAttention" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("attn_out.weight", "AfterAttention.attn_out.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.blocks.*.ff_proj.weight") and "MLPResidual" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("ff_proj.weight", "MLPResidual.ff_proj.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.blocks.*.ff_out.weight") and "MLPResidual" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("ff_out.weight", "MLPResidual.ff_out.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+
+ # For loading a state dict that was saved with a different `block_group_size`.
+ if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
+ state_dict_block_group_size = len(
+ [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
+ )
+ else:
+ state_dict_block_group_size = 1
+ if self.config.block_group_size != state_dict_block_group_size:
+ log.info(
+ f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
+ f"group size {self.config.block_group_size}"
+ )
+ # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
+ # and then (re-)group them into the right block sizes.
+ if state_dict_block_group_size > 1:
+ for key in list(state_dict.keys()):
+ if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
+ group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
+ block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
+ state_dict[
+ (
+ new_key := key.replace(
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
+ )
+ )
+ ] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
+
+ if self.config.block_group_size > 1:
+ # Group the state dict blocks into the right block size.
+ for key in list(state_dict.keys()):
+ if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
+ block_idx = int(m.group(1))
+ group_idx, group_block_idx = (
+ block_idx // self.config.block_group_size,
+ block_idx % self.config.block_group_size,
+ )
+ state_dict[
+ (
+ new_key := key.replace(
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
+ )
+ )
+ ] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
+
+ og_keys_to_new: dict[str, set[str]] = defaultdict(set)
+ for new_key, og_key in new_keys_to_og_keys.items():
+ og_keys_to_new[og_key].add(new_key)
+
+ return state_dict, og_keys_to_new
diff --git a/llava/model/coat/activation/real_quantization/__init__.py b/llava/model/coat/activation/real_quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61875efb5d6a495fdab36438ea99a8784a88c4cf
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Activation
+# Utils
+from ._dequantize import fp8_dequantize
+from ._division import fp8_division
+from ._division_transpose import fp8_division_transpose
+from ._quantize import fp8_quantize
+from ._quantize_pertensor import fp8_quantize_pertensor
+from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose
+from ._transpose import fp8_transpose
+from .add_bwd import fp8_add_Ifp_Ifp_Ofp_Opt
+from .add_fwd import fp8_add_Ifp_Ifp_Ofp_Og16
+
+# Normalization
+from .func_layernorm_noparam import fp8_layernorm_noparam_backward, fp8_layernorm_noparam_forward
+from .func_quantize import Coat_quantize_bgn, Coat_quantize_end
+from .func_rmsnorm import fp8_rmsnorm_backward, fp8_rmsnorm_forward
+from .gelu_bwd import fp8_gelu_backward
+from .gelu_fwd import fp8_gelu_forward
+
+# linear and add
+from .linear import fp8_linear_backward, fp8_linear_forward
+from .mul_bwd import fp8_mul_backward
+from .mul_fwd import fp8_mul_forward
+from .silu_bwd import fp8_silu_backward
+from .silu_fwd import fp8_silu_forward
diff --git a/llava/model/coat/activation/real_quantization/_dequantize.py b/llava/model/coat/activation/real_quantization/_dequantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..4347932f9c3691ffada9d59258337aac75b74942
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_dequantize.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""Quantize Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_dequantize_kernel(
+ output_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of output stride
+ output_stride_0,
+ output_stride_1, # output stride
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+    # Dequantize: multiply each 1x16 group by its scale factor
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ output = input * scale_input
+
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+ output = output.to(output_ptr.dtype.element_ty)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+def fp8_dequantize(x, s_x, QB):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ y = torch.empty_like(x, dtype=torch.bfloat16)
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_dequantize_kernel[grid](
+ y,
+ x,
+ s_x,
+ M,
+ N,
+ SN,
+ QB,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y
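+
+
+# Hedged usage sketch (assumes CUDA tensors and the 1x16 group size, QB=16; the fp8 type string is
+# whatever key convert_str_to_fp8 in .common accepts, e.g. "E4M3"):
+#   y_fp8, s_y = fp8_quantize(x, 16, "E4M3")   # from ._quantize
+#   x_rec = fp8_dequantize(y_fp8, s_y, 16)     # BF16 reconstruction of x, up to quantization error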
diff --git a/llava/model/coat/activation/real_quantization/_division.py b/llava/model/coat/activation/real_quantization/_division.py
new file mode 100644
index 0000000000000000000000000000000000000000..335ff1a67795dc858a740c6618466ff79fb554cb
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_division.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Quantize and Transpose Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_division_kernel(
+ output_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit: tl.constexpr,
+ m_bit: tl.constexpr, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+    SCALE_MIN_THRES: tl.constexpr,  # unused here: the threshold is already applied when the scale is computed in the preceding kernel
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+ scale_output = tl.load(input_scale_ptr)
+ scale_output = scale_output.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ # Quantize
+ output = tl.fdiv(output, scale_output)
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ # noise_block_ptr = tl.make_block_ptr(
+ # base=noise_ptr,
+ # shape=(M, N),
+ # strides=(input_stride_0, input_stride_1),
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ # block_shape=(BLOCK_M, BLOCK_N),
+ # order=(1, 0)
+ # )
+ # noise = tl.load(noise_block_ptr)
+
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
+ noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
+ noise = tl.rand(0, noise_offset)
+
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
+
+ output = output.to(output_ptr.type.element_ty)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+@triton.jit
+def _stochastic_rounding(output, noise, e_bit: tl.constexpr, m_bit: tl.constexpr):
+ subnormal_min = tl.exp2(2 - tl.exp2(e_bit - 1) - m_bit)
+ # subnormal_should_be = tl.exp2(2 - tl.exp2(e_bit) - 1)
+
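+    # Bitcast to int32 and keep only the IEEE-754 float32 exponent bits (mask 0x7F800000) to recover
+    # the power-of-two magnitude of each element; values below the smallest FP8 subnormal are clamped
+    # to subnormal_min so the rounding interval computed below stays well defined.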
+ output_int32 = tl.cast(output, tl.int32, bitcast=True)
+ output_int32 = output_int32 & 0x7F800000
+ output_float32 = tl.cast(output_int32, tl.float32, bitcast=True)
+ output_exp = tl.maximum(output_float32, subnormal_min)
+
+ noise_rescale = tl.exp2(m_bit) + (output_exp == subnormal_min) * (
+ 1 - tl.exp2(m_bit)
+ ) # 2^m_bit for normal, 1 for subnormal
+
+ noise = output_exp * noise / noise_rescale
+ sign = 1 - 2 * libdevice.signbit(output)
+ output = tl.abs(output) + noise
+
+ minmax_ratio = 2 + (output_exp == subnormal_min) * (tl.exp2(m_bit) - 2) # 2 for normal, and 2^M for subnormal
+ output = sign * tl.clamp(output, min=output_exp, max=minmax_ratio * output_exp)
+
+ return output
+
+
+def fp8_division(x, QB, fp8type, s_y=None, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+    # Stochastic-rounding noise is generated inside the kernel via tl.rand, so no host-side
+    # noise tensor is needed.
+    noise = None
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(x, dtype=fp8type)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ if s_y is None:
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_division_kernel[grid](
+ y,
+ x,
+ s_y,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+    return y, s_y
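+
+
+# Hedged usage sketch (per-tensor FP8 cast; the fp8 type key is assumed to be one accepted by
+# convert_str_to_fp8 in .common, e.g. "E4M3"):
+#   y_fp8, s_y = fp8_division(x, 16, "E4M3")            # scale derived from x.abs().max()
+#   y_fp8, s_y = fp8_division(x, 16, "E4M3", s_y=scale) # reuse a precomputed per-tensor scale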
diff --git a/llava/model/coat/activation/real_quantization/_division_transpose.py b/llava/model/coat/activation/real_quantization/_division_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fbabaa31e4f47482c98d69431a9dfae488a2f15
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_division_transpose.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import _stochastic_rounding
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Division and Transpose Operator"""
+"""Input uses full-precision/BF16"""
+"""Output uses per tensor quantization"""
+"""Output_t uses per tensor quantization and is transposed, but is flattened to 2D"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+    configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_division_transpose_kernel(
+ output_ptr,
+ output_t_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ output_t_stride_0,
+ output_t_stride_1, # output stride
+    SCALE_MIN_THRES: tl.constexpr,  # unused here: the threshold is already applied when the scale is computed in the preceding kernel
+ STOCHASTIC: tl.constexpr,
+ ONLY_TRANSPOSED: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+ scale_output = tl.load(input_scale_ptr)
+ scale_output = scale_output.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ # Quantize
+ output = tl.fdiv(output, scale_output)
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ # noise_block_ptr = tl.make_block_ptr(
+ # base=noise_ptr,
+ # shape=(M, N),
+ # strides=(input_stride_0, input_stride_1),
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ # block_shape=(BLOCK_M, BLOCK_N),
+ # order=(1, 0)
+ # )
+ # noise = tl.load(noise_block_ptr)
+
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
+ noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
+ noise = tl.rand(0, noise_offset)
+
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
+
+ output = output.to(output_ptr.type.element_ty)
+ output_t = tl.trans(output)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ output_t_block_ptr = tl.make_block_ptr(
+ base=output_t_ptr,
+ shape=(N, M),
+ strides=(output_t_stride_0, output_t_stride_1),
+ offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
+ block_shape=(BLOCK_N, BLOCK_M),
+ order=(1, 0),
+ )
+ if not ONLY_TRANSPOSED:
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+ tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1))
+
+
+def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False, only_transposed=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+    # Stochastic-rounding noise is generated inside the kernel via tl.rand, so no host-side
+    # noise tensor is needed.
+    noise = None
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(x, dtype=fp8type)
+ y_t = torch.empty((N, M), dtype=fp8type, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ if s_y is None:
+        # NOTE: s_y should normally be provided by the caller; fall back to a per-tensor scale computed from x.
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_division_transpose_kernel[grid](
+ y,
+ y_t,
+ x,
+ s_y,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ y_t.stride(0),
+ y_t.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ ONLY_TRANSPOSED=only_transposed,
+ )
+
+ if not only_transposed:
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y, s_y, y_t # y_t is expected to be 2D tensor
+ else:
+ return y_t, s_y
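+
+
+# Hedged usage sketch: a single launch yields the FP8 tensor and its transpose (the transposed copy
+# is typically what the backward-pass GEMMs consume); the fp8 type key is assumed as in fp8_division:
+#   y, s_y, y_t = fp8_division_transpose(x, 16, "E4M3")
+#   y_t, s_y = fp8_division_transpose(x, 16, "E4M3", only_transposed=True)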
diff --git a/llava/model/coat/activation/real_quantization/_memory_io.py b/llava/model/coat/activation/real_quantization/_memory_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04d2a7afcf749b634cbd5d88c6d99b42030d5ff
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_memory_io.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+CONST_BLOCK = 32
+
+# Autotune configurations for the memory-bound benchmark kernel below (one load and one store per block)
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5, 6]:
+ for block_m in [32, 64, 128]:
+ for block_n in [32, 64, 128]:
+ for nwarps in [4, 8, 16, 32]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.jit
+def bench_memory_io_kernel_forward(
+ output_ptr,
+ input_ptr,
+ M,
+ N,
+ B: tl.constexpr,
+ input_stride_0,
+ input_stride_1,
+ output_stride_0,
+ output_stride_1,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_M = tl.cdiv(M, BLOCK_M)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = input * 2
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ output = output.to(output_ptr.type.element_ty)
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+def bench_memory_io_forward(x, B):
+ # defining the input and output tensor
+ M, N = x.shape
+
+ y = torch.empty_like(x, dtype=x.dtype)
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ bench_memory_io_kernel_forward[grid](
+ y,
+ x,
+ M,
+ N,
+ B,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ )
+ return y
+
+
+configs = []
+for SL in [8192]:
+ configs.append(
+ triton.testing.Benchmark( # test different matrix size influence
+ x_names=["CDIM"],
+ x_vals=[1024, 2048, 4096, 8192],
+ line_arg="dtype",
+ line_vals=[torch.int8, torch.float16, torch.float32],
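+            # NOTE: torch.int8 stands in for FP8 storage here; both are 1 byte per element, which is
+            # what matters for a pure load/store bandwidth benchmark.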
+ line_names=["float8", "float16", "float32"],
+ styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+ ylabel="time-cost",
+            plot_name="memory_io_bench",
+ args={"SL": SL, "B": CONST_BLOCK, "provider": "triton", "mode": "time-consuming"},
+ )
+ )
+
+
+@triton.testing.perf_report(configs)
+def bench_load_store(
+ SL, CDIM, B, provider, dtype, mode="forward"
+): # I only use triton as the provider, and mode when benchmarking
+ # create data
+ x = torch.randn(SL, CDIM, dtype=torch.float32).cuda()
+ x = x.to(dtype)
+
+ quantiles = [0.5, 0.2, 0.8]
+ # utility functions
+ if provider == "triton":
+
+ def y_fwd():
+ bench_memory_io_forward(x, B)
+
+ if provider == "torch":
+ torch_gelu = torch.nn.GELU()
+
+ def y_fwd():
+ return torch_gelu(x)
+
+ # forward pass
+ if mode == "time-consuming":
+ convert_func = lambda ms: ms
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ # backward pass
+ if mode == "gbps":
+ convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ return convert_func(ms), convert_func(max_ms), convert_func(min_ms)
+
+
+if __name__ == "__main__":
+ torch.manual_seed(0)
+ torch.set_printoptions(precision=8, linewidth=1600, sci_mode=False, edgeitems=3)
+ bench_load_store.run(print_data=True)
diff --git a/llava/model/coat/activation/real_quantization/_quantize.py b/llava/model/coat/activation/real_quantization/_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1385b7dc42f25f0505a42cfecad48d9fa6eaa51
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_quantize.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Quantize Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_quantize_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ output = tl.fdiv(output, scale_output)
+
+ output = output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize(x, QB, fp8type):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y = torch.empty_like(x, dtype=fp8type)
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_quantize_kernel[grid](
+ y,
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
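+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# A minimal example of how fp8_quantize is expected to be driven: a BF16 activation
+# is turned into an FP8 tensor plus per-1x16-group scales. The shapes, the group
+# size of 16, the "E4M3" format and the CUDA device are assumptions for illustration.
+def _example_fp8_quantize():
+    x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
+    qx, s_x = fp8_quantize(x, 16, "E4M3")
+    # qx: (2, 128, 1024) torch.float8_e4m3fn; s_x: (2, 128, 64) BF16 group scales
+    return qx, s_x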
diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..46cadbc6e77e2e34c1fafec968aa17c5621da034
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""Per Tensor Quantize Operator"""
+"""Input uses full precision"""
+"""Output uses per tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_quantize_pertensor_kernel(
+ output_scale_ptr, # output
+ input_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize_pertensor(x, QB, fp8type, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ fp8type = convert_str_to_fp8[fp8type]
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_quantize_pertensor_kernel[grid](
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ y, s_y_max = fp8_division(x, QB, fp8type, s_y_max, stochastic=stochastic)  # quantize x using the per-tensor max scale
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y_max, s_y
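+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# fp8_quantize_pertensor first computes per-group maxima, reduces them to one
+# per-tensor scale on the host, and then quantizes through fp8_division. The shapes
+# and the "E4M3" format below are assumptions for illustration.
+def _example_fp8_quantize_pertensor():
+    x = torch.randn(4, 256, 512, dtype=torch.bfloat16, device="cuda")
+    qx, s_max, s_group = fp8_quantize_pertensor(x, 16, "E4M3")
+    # qx: FP8 tensor quantized with the scalar scale s_max;
+    # s_group: the intermediate (4, 256, 32) per-group scales
+    return qx, s_max, s_group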
diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e91d01445c99928aa7f79227cdbc7706fee1981
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""Per Tensor Quantize and Transpose Operator"""
+"""Input uses floating point tensor"""
+"""Output uses per-tensor quantization, returns a non-transpose version and a transpose version"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_quantize_pertensor_transpose_kernel(
+ output_scale_ptr, # output
+ input_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize_pertensor_transpose(x, QB, fp8type, transpose_output_2d=False, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ fp8type = convert_str_to_fp8[fp8type]
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_quantize_pertensor_transpose_kernel[grid](
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(
+ x, QB, fp8type, s_y_max, stochastic=stochastic
+ ) # Stochastic Rounding happens here
+
+ # Recover 2D to 3D
+ if batched:
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy_t.shape[-1])
+
+ return qy, s_y_max, qy_t # y_t is expected to be 2D tensor
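+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# The quantize-and-transpose helper returns both the quantized tensor and its
+# transpose, the two layouts needed by the FP8 GEMMs in the linear layer. The weight
+# shape and the "E4M3" format below are assumptions for illustration.
+def _example_fp8_quantize_pertensor_transpose():
+    w = torch.randn(4096, 1024, dtype=torch.bfloat16, device="cuda")
+    qw, s_w, qw_t = fp8_quantize_pertensor_transpose(w, 16, "E4M3", transpose_output_2d=True)
+    # qw: (4096, 1024) FP8; qw_t: (1024, 4096) FP8; s_w: scalar per-tensor scale
+    return qw, s_w, qw_t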
diff --git a/llava/model/coat/activation/real_quantization/_transpose.py b/llava/model/coat/activation/real_quantization/_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0b8ec2dc499d2bcd661e619d9f54e91daeac0c
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_transpose.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import get_configs_io_block
+
+"""Quantize Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.jit
+def _fp8_transpose_kernel(
+ output_ptr, # output
+ input_ptr, # input
+ M,
+ N, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+
+ output = tl.trans(input)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(N, M),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
+ block_shape=(BLOCK_N, BLOCK_M),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+def fp8_transpose(x, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+
+ y = torch.empty((N, M), dtype=x.dtype, device=x.device)
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_transpose_kernel[grid](
+ y,
+ x,
+ M,
+ N,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ )
+
+ # Recover 2D to 3D
+ if batched and not transpose_output_2d:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y
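+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# fp8_transpose materializes the transpose of an FP8 tensor with a tiled Triton
+# kernel, giving a contiguous transposed operand rather than a strided view. The
+# shape below is an assumption for illustration.
+def _example_fp8_transpose():
+    x = torch.randn(256, 1024, device="cuda").to(torch.float8_e4m3fn)
+    x_t = fp8_transpose(x)
+    # x_t: (1024, 256), same FP8 dtype as x
+    return x_t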
diff --git a/llava/model/coat/activation/real_quantization/add_bwd.py b/llava/model/coat/activation/real_quantization/add_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..5951afec125048d8549fd1be0ae204df1a6b222a
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/add_bwd.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""Element-wise Add, useful for backward"""
+"""Input1 (Residual) uses full-precision/BF16"""
+"""Input2 (Backbone) uses full-precision/BF16"""
+"""Output1 uses full-precision/BF16"""
+"""Output2 uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_add_Ifp_Ifp_Ofp_Opt_kernel(
+ output1_ptr, # output
+ output2_scale_ptr,
+ input1_ptr, # input
+ input2_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ input1 = input1.to(tl.float32)
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ input2 = input2.to(tl.float32)
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of Add
+ add_output = input1 + input2
+
+ # Scale calculation for the quantized gradient output (output2)
+ abs_add_output = tl.abs(add_output)
+ max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES
+ scale_output2 = max_val / fp8_max
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1))
+
+ # save the fp add output
+ fp_add_output = add_output.to(output1_ptr.type.element_ty)
+ fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output1_block_ptr, fp_add_output, boundary_check=(0, 1))
+
+ # Quantize
+ scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty)
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN))
+
+ # pointers
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1))
+
+
+def fp8_add_Ifp_Ifp_Ofp_Opt(x1, x2, QB, fp8type, stochastic=False): # suppose x1 is full precision or BF16
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(x2.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ SN = N // QB
+ assert x1.shape == x2.shape
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y1 = torch.empty_like(x1, dtype=torch.bfloat16)
+ s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_add_Ifp_Ifp_Ofp_Opt_kernel[grid](
+ y1,
+ s_y2,
+ x1,
+ x2,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x1.stride(0),
+ x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y2_max = s_y2.max()
+ qy2, s_y2_max = fp8_division(y1, QB, fp8type, s_y2_max, stochastic=stochastic) # reuse the floating point output y1
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+ qy2 = qy2.reshape(BS, -1, qy2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (qy2, s_y2_max, s_y2)
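+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# The backward-pass add fuses the residual addition with quantization of the result:
+# it returns the BF16 sum together with a per-tensor-quantized copy for the FP8
+# matmul that consumes it. Shapes and the "E5M2" gradient format are assumptions.
+def _example_fp8_add_bwd():
+    g_residual = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
+    g_backbone = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
+    y, (qy, s_max, s_group) = fp8_add_Ifp_Ifp_Ofp_Opt(g_residual, g_backbone, 16, "E5M2")
+    # y: BF16 sum; qy: E5M2 per-tensor-quantized sum; s_max: scalar scale
+    return y, qy, s_max, s_group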
diff --git a/llava/model/coat/activation/real_quantization/add_fwd.py b/llava/model/coat/activation/real_quantization/add_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..90aed0b3573e4c1a21b4eecea959f6f537f61be8
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/add_fwd.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""Element-wise Add, used in forward pass"""
+"""Input1 (Residual) uses full-precision/BF16"""
+"""Input2 (Backbone) uses full-precision/BF16"""
+"""Output1 uses full-precision/BF16"""
+"""Output2 uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_add_Ifp_Ifp_Ofp_Og16_kernel(
+ output1_ptr, # output
+ output2_ptr,
+ output2_scale_ptr,
+ input1_ptr, # input
+ input2_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ input1 = input1.to(tl.float32)
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ input2 = input2.to(tl.float32)
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of Add
+ add_output = input1 + input2
+
+ # Scale calculation for the quantized output (output2)
+ abs_add_output = tl.abs(add_output)
+ max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES
+ scale_output2 = max_val / fp8_max
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1))
+
+ # save the fp add output
+ fp_add_output = add_output.to(output1_ptr.type.element_ty)
+ fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output1_block_ptr, fp_add_output)
+
+ # Quantize
+ add_output = tl.fdiv(add_output, scale_output2)
+ scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty)
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN))
+ add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N))
+
+ add_output = add_output.to(output2_ptr.type.element_ty)
+ add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, add_output, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1))
+
+
+def fp8_add_Ifp_Ifp_Ofp_Og16(x1, x2, fp8type, QB): # suppose x1 is full precision or BF16
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ SN = int(N / QB) # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+
+ y1 = torch.empty_like(x1, dtype=torch.bfloat16)
+ y2 = torch.empty_like(x2, dtype=fp8type)
+ s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_add_Ifp_Ifp_Ofp_Og16_kernel[grid](
+ y1,
+ y2,
+ s_y2,
+ x1,
+ x2,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x1.stride(0),
+ x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (y2, s_y2)
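+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# The forward-pass add returns the BF16 residual sum plus a 1x16 group-quantized
+# copy. Note that, unlike the backward variant, fp8type is a torch dtype here and
+# precedes QB in the signature. Shapes below are assumptions for illustration.
+def _example_fp8_add_fwd():
+    residual = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
+    backbone = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
+    y, (qy, s_group) = fp8_add_Ifp_Ifp_Ofp_Og16(residual, backbone, torch.float8_e4m3fn, 16)
+    # y: BF16 sum; qy: FP8 sum with (2, 128, 64) group scales s_group
+    return y, qy, s_group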
diff --git a/llava/model/coat/activation/real_quantization/common.py b/llava/model/coat/activation/real_quantization/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..00de715cf3e80eddf74ea401e1bbc9811c3e4970
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/common.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+
+SCALE_MIN_THRES = 1e-10
+
+FP8_MAX_VALUE = {
+ torch.float8_e4m3fn: 448.0,
+ torch.float8_e5m2: 57344.0,
+}
+
+convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
+convert_fp8_to_embit = {
+ torch.float8_e4m3fn: (4.0, 3.0),
+ torch.float8_e5m2: (5.0, 2.0),
+}
+
+
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5]:
+ for block_m in [32, 64]:
+ for block_n in [64, 128]:
+ for nwarps in [4, 8]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+# from .common import SCALE_MIN_THRES, FP8_MAX_VALUE
+# SCALE_MIN_THRES: tl.constexpr,
+# + SCALE_MIN_THRES
+# SCALE_MIN_THRES=SCALE_MIN_THRES,
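+
+
+# --- Illustrative sanity check (editor's sketch; not called anywhere in this package) ---
+# The hard-coded FP8 ranges can be cross-checked against torch.finfo, and
+# convert_str_to_fp8 maps the "E4M3"/"E5M2" strings used by the kernels to dtypes.
+def _example_fp8_constants():
+    for name, dtype in convert_str_to_fp8.items():
+        assert FP8_MAX_VALUE[dtype] == torch.finfo(dtype).max, name
+    # e.g. convert_str_to_fp8["E4M3"] is torch.float8_e4m3fn with max 448.0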
diff --git a/llava/model/coat/activation/real_quantization/fp8linear.py b/llava/model/coat/activation/real_quantization/fp8linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..32531fcb4bb63e4720e3fc48fb08b1f2d82d54db
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/fp8linear.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import time
+from copy import deepcopy
+from dataclasses import dataclass
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function
+
+from ..utils import quant_get_local_rank
+from ._division_transpose import fp8_division_transpose
+from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose
+from .linear import fp8_linear_backward, fp8_linear_forward
+
+
+@dataclass
+class DefaultArgs:
+ fabit: str  # FP8 format for forward activations, e.g. "E4M3"
+ fwbit: str  # FP8 format for weights, e.g. "E4M3"
+ bobit: str  # FP8 format for backward gradients, e.g. "E5M2"
+
+
+class FP8Linear(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True, device=None, args=None, layer_idx=0):
+ super().__init__(in_features, out_features, bias, device)
+
+ if args is None: # fall back to environment variables so callers (e.g. OLMo) do not need an extra argument
+ args = DefaultArgs(
+ fabit=os.environ["FABIT_FP8Linear"],
+ fwbit=os.environ["FWBIT_FP8Linear"],
+ bobit=os.environ["BOBIT_FP8Linear"],
+ )
+ self.args = deepcopy(args)
+
+ if quant_get_local_rank() == 0:
+ print(f"[qlinear debug] Apply QLinear, {layer_idx}")
+
+ self.layer_idx = layer_idx
+ self.layer_name = None
+
+ def forward(self, Input):
+ if self.training:
+ # if False:
+ output = QuantLinearTE.apply(Input, self.weight, self.bias, self.args, self.layer_name)
+ else:
+ output = F.linear(Input, self.weight, self.bias)
+
+ return output
+
+
+# class QuantLinearTE(Function):
+# @staticmethod
+# def forward(ctx, input, weight, bias, args, layer_type):
+# ctx.saved = input, weight, bias, args, layer_type
+# return F.linear(input, weight, bias)
+
+# @staticmethod
+# def backward(ctx, grad_output):
+# input, weight, bias, args, layer_type = ctx.saved
+
+# C_in = input.shape[-1]
+# C_out = grad_output.shape[-1]
+
+# grad_output_flatten = grad_output.reshape(-1, C_out)
+# input_flatten = input.reshape(-1, C_in)
+
+# if grad_output_flatten.dtype == input_flatten.dtype:
+# grad_weight = grad_output_flatten.t().mm(input_flatten)
+# else:
+# grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+
+# if grad_output_flatten.dtype == weight.dtype:
+# grad_input = grad_output_flatten.mm(weight)
+# else:
+# grad_input = grad_output_flatten.float().mm(weight)
+
+# if bias is not None:
+# grad_bias = grad_output_flatten.sum(0)
+# else:
+# grad_bias = None
+
+# grad_input_transform = grad_input.reshape(input.size())
+
+# return grad_input_transform, grad_weight, grad_bias, None, None
+
+
+class QuantLinearTE(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16)
+ def forward(ctx, input, weight, bias, args, layer_name):
+
+ time_bench = os.getenv("TIME_BENCH")
+
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Qinput, Iscale, Qinput_t = fp8_division_transpose(input, 16, args.fabit)
+ Qinput, Iscale, Qinput_t = fp8_quantize_pertensor_transpose(input, 16, args.fabit, transpose_output_2d=True)
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ # Qweight, Wscale, Qweight_t = fp8_division_transpose(weight, 16, args.fwbit)
+ Qweight, Wscale, Qweight_t = fp8_quantize_pertensor_transpose(weight, 16, args.fwbit, transpose_output_2d=True)
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ ctx.saved = Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name
+ fc_output = fp8_linear_forward(Qinput, Iscale, Qweight, Wscale, False, 0, bias)
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ output = F.linear(input, weight, bias)
+
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if quant_get_local_rank() == 0:
+ print(
+ f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {input.shape} | Weight shape: {weight.shape}"
+ )
+
+ return fc_output
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_output):
+ Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
+
+ time_bench = os.getenv("TIME_BENCH")
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Qgrad_output, Gscale, Qgrad_output_t = fp8_division_transpose(grad_output, 16, args.bobit, stochastic=False)
+ Qgrad_output, Gscale, Qgrad_output_t = fp8_quantize_pertensor_transpose(
+ grad_output, 16, args.bobit, stochastic=False, transpose_output_2d=True
+ )
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ grad_input, grad_weight = fp8_linear_backward(
+ Qinput_t,
+ Iscale,
+ Qgrad_output,
+ Gscale,
+ Qgrad_output_t,
+ Qweight_t,
+ Wscale,
+ 16,
+ bias,
+ stochastic=False,
+ dgrad_quantize=False,
+ )
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ if bias is not None:
+ grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+ else:
+ grad_bias = None
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+
+ # ========== BF16 ==========
+ C_in = Qinput_t.shape[0]
+ C_out = grad_output.shape[-1]
+ grad_output_flatten = grad_output.reshape(-1, C_out)
+ input_flatten = Qinput_t.t().reshape(-1, C_in).to(torch.bfloat16)
+ weight = Qweight_t.t().to(torch.bfloat16)
+
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ if grad_output_flatten.dtype == input_flatten.dtype:
+ _grad_weight = grad_output_flatten.t().mm(input_flatten)
+ else:
+ _grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+
+ if grad_output_flatten.dtype == weight.dtype:
+ _grad_input = grad_output_flatten.mm(weight)
+ else:
+ _grad_input = grad_output_flatten.float().mm(weight)
+
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if quant_get_local_rank() == 0:
+ print(
+ f"[Backward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {Qinput_t.shape} | Weight shape: {weight.shape}"
+ )
+
+ return grad_input, grad_weight, grad_bias, None, None
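+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# FP8Linear behaves like nn.Linear at eval time and routes through QuantLinearTE
+# during training. The bit-width strings ("E4M3" activations/weights, "E5M2"
+# gradients), shapes and CUDA device below are assumptions for illustration.
+def _example_fp8_linear():
+    args = DefaultArgs(fabit="E4M3", fwbit="E4M3", bobit="E5M2")
+    layer = FP8Linear(1024, 4096, bias=True, device="cuda", args=args).to(torch.bfloat16)
+    x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    layer.train()
+    out = layer(x)        # FP8 forward via fp8_linear_forward
+    out.sum().backward()  # FP8 backward via fp8_linear_backward
+    return out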
diff --git a/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3a1f57f3d1db8b19da687fe823a83f35d97d7d6
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py
@@ -0,0 +1,303 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Layer Normalization
+====================
+In this tutorial, you will write a high-performance layer normalization
+kernel that runs faster than the PyTorch implementation.
+
+In doing so, you will learn about:
+
+* Implementing backward pass in Triton.
+
+* Implementing parallel reduction in Triton.
+
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from ._division import _stochastic_rounding
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit
+
+"""FP8 LayerNorm. Forward + Backward"""
+"""Forward: input uses 1 * 16 quantization"""
+"""Forward: output use per-tensor quantization."""
+"""Backward: input uses full-precision/BF16."""
+"""Backward: output uses full-precision/BF16"""
+"""Support 3D input, but need to first flatten to 2D to perform calculation."""
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _layer_norm_fwd_fused(
+ X, # pointer to the input
+ SX, # pointer to the scale of input
+ Y, # pointer to the output
+ SY, # pointer to the scale of output
+ Mean, # pointer to the mean
+ Rstd, # pointer to the 1/std
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X,
+ N2: tl.constexpr, # number of columns in X,
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ eps, # epsilon to avoid division by zero
+ fp8_max,
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ SX += row * scale_stride
+
+ mean = 0
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+
+ # Dequantize the group-quantized input
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, (N2,))
+
+ # Compute mean and Variance
+ mean = tl.sum(x, axis=0) / N
+ # Compute variance
+ _var = (x - mean) * (x - mean)
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # Write mean / rstd
+ tl.store(Mean + row, mean)
+ tl.store(Rstd + row, rstd)
+ # Normalize and apply linear transformation
+ x_hat = (x - mean) * rstd
+
+ # Scale calculation
+ abs_x_hat = tl.abs(x_hat)
+ max_val = tl.max(abs_x_hat, axis=0) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = scale_output.to(SY.type.element_ty)
+
+ tl.store(SY + row, scale_output)
+
+ # Write output
+ tl.store(Y + cols, x_hat, mask=cols < N)
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _layer_norm_bwd_dx_fused(
+ DX, # pointer to the input gradient
+ DY, # pointer to the output gradient
+ X, # pointer to the input
+ SX, # pointer to the input
+ noise_ptr, # noise for stochastic
+ Mean, # pointer to the mean
+ Rstd, # pointer to the 1/std
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X
+ N2: tl.constexpr, # number of columns in X
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+):
+ # Map the program id to the elements of X, DX, and DY it should compute.
+ row = tl.program_id(0)
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ mask = cols < N
+ X += row * stride
+ DY += row * stride
+ DX += row * stride
+ SX += row * scale_stride
+ # Load data to SRAM
+ x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, N2)
+
+ dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
+
+ mean = tl.load(Mean + row)
+ rstd = tl.load(Rstd + row)
+ # Compute dx
+ xhat = (x - mean) * rstd
+ xhat = tl.where(mask, xhat, 0.0)
+ dy = tl.where(mask, dy, 0.0)
+ c1 = tl.sum(xhat * dy, axis=0) / N
+ c2 = tl.sum(dy, axis=0) / N
+ dx = (dy - (xhat * c1 + c2)) * rstd
+
+ if STOCHASTIC:
+ # noise_ptr += row * stride
+ # noise_block_ptr = noise_ptr + cols
+ # noise = tl.load(noise_block_ptr, mask=mask, other=0.)
+
+ noise_offset = row * stride + cols
+ noise = tl.rand(0, noise_offset)
+
+ # NOTE: e_bit / m_bit (the FP8 exponent/mantissa widths, see convert_fp8_to_embit)
+ # are not passed into this kernel, so the STOCHASTIC=True path cannot compile as-is.
+ dx = _stochastic_rounding(dx, noise, e_bit, m_bit)
+
+ dx = dx.to(DX.type.element_ty)
+
+ # Write dx
+ tl.store(DX + cols, dx, mask=mask)
+
+
+def fp8_layernorm_noparam_forward(x, s_x, QB, eps, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # allocate output
+ M, N = x.shape
+ _, SN = s_x.shape
+ y = torch.empty_like(x, dtype=torch.float32)
+ s_y = torch.empty(
+ (M,), dtype=torch.bfloat16, device=x.device
+ ) # We do this because we apply per-tensor quantization for it afterwards
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device)
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+ # heuristics for number of warps
+ num_warps = 8
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype]
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = N2 // QB
+ # enqueue kernel
+ _layer_norm_fwd_fused[(M,)]( #
+ x,
+ s_x,
+ y,
+ s_y,
+ mean,
+ rstd, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ eps,
+ fp8MaxValue,
+ SCALE_MIN_THRES=SCALE_MIN_THRES, #
+ num_warps=num_warps,
+ num_ctas=1,
+ )
+ # reduction
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, y.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, y.shape[-1])
+
+ return qy, s_y_max, qy_t, (mean, rstd, num_warps)
+
+
+def fp8_layernorm_noparam_backward(x, s_x, g, QB, m, v, num_warps, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ if stochastic:
+ # noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ noise = None
+ else:
+ noise = None
+
+ # heuristics for amount of parallel reduction stream for DW/DB
+ dx = torch.empty_like(g, dtype=torch.bfloat16)
+ # enqueue kernel using forward pass heuristics
+ # also compute partial sums for DW and DB
+ M, N = g.shape
+ _, SN = s_x.shape
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = triton.next_power_of_2(SN)
+ _layer_norm_bwd_dx_fused[(M,)]( #
+ dx,
+ g,
+ x,
+ s_x,
+ noise,
+ m,
+ v, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ num_warps=num_warps,
+ )
+
+ if batched:
+ dx = dx.reshape(BS, -1, dx.shape[-1])
+
+ return dx
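+
+
+# --- Illustrative usage (editor's sketch; not exercised by this package) ---
+# The forward expects an already group-quantized activation (qx, s_x), so a natural
+# way to drive it is through fp8_quantize from this package. The local import,
+# shapes and "E4M3" format below are assumptions for illustration.
+def _example_fp8_layernorm_noparam():
+    from ._quantize import fp8_quantize
+
+    x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
+    qx, s_x = fp8_quantize(x, 16, "E4M3")
+    qy, s_y_max, qy_t, (mean, rstd, num_warps) = fp8_layernorm_noparam_forward(qx, s_x, 16, 1e-5)
+    g = torch.randn_like(x)  # upstream gradient in BF16
+    dx = fp8_layernorm_noparam_backward(qx, s_x, g, 16, mean, rstd, num_warps)
+    return qy, dx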
diff --git a/llava/model/coat/activation/real_quantization/func_quantize.py b/llava/model/coat/activation/real_quantization/func_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..f63942146890a02132b431a10df39f22a65ebec4
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/func_quantize.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+from ._quantize import fp8_quantize
+from ._quantize_pertensor import fp8_quantize_pertensor
+
+
+class Coat_quantize_bgn(nn.Module):
+ def __init__(self, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.fp8type = self.args.fabit
+ self.layer_type = layer_type
+
+ def forward(self, input):
+ if self.training:
+ return Coat_quantize_bgn_func.apply(input, self.args.group_size, self.fp8type)
+ else:
+ return input, None, None
+
+
+class Coat_quantize_bgn_func(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, group_size, fp8type):
+ """
+ (Qoutput, Oscale) uses 1 * 16 quantization
+ """
+ Qoutput, Oscale = fp8_quantize(input, group_size, fp8type)
+ # For autograd
+ Qoutput = Qoutput.view(torch.float8_e4m3fn)
+ ctx.saved = group_size
+ return input, Qoutput, Oscale
+
+ @staticmethod
+ def backward(ctx, grad_output, Qgrad_output, Gscale):
+ """
+ (Qgrad_output, Gscale) uses 1 * 16 quantization
+ """
+ return grad_output, None, None
+
+
+class Coat_quantize_end(nn.Module):
+ def __init__(self, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.fp8type = self.args.babit
+ self.layer_type = layer_type
+
+ def forward(self, input, Qinput, Iscale):
+ if self.training:
+ return Coat_quantize_end_func.apply(input, Qinput, Iscale, self.args.group_size, self.fp8type)
+ else:
+ return input
+
+
+class Coat_quantize_end_func(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, Qinput, Iscale, group_size, fp8type):
+ """
+ (Qinput, Iscale) uses 1 * 16 quantization
+ """
+ ctx.saved = group_size, fp8type
+
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ (Qgrad_output, Gscale) uses per-tensor quantization
+ """
+
+ group_size, fp8type = ctx.saved
+ Qgrad_output, Gscale, Gscale_g16 = fp8_quantize_pertensor(grad_output, group_size, fp8type, stochastic=False)
+
+ # For autograd
+ Qgrad_output = Qgrad_output.view(torch.float8_e4m3fn)
+
+ return grad_output, Qgrad_output, Gscale_g16, None, None
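+
+
+# --- Illustrative usage (editor's sketch; not called anywhere in this package) ---
+# Coat_quantize_bgn / Coat_quantize_end bracket a stack of FP8 layers: "bgn" emits
+# (input, Qinput, Iscale) for the quantized path, and "end" folds it back into a
+# single tensor while its backward re-quantizes the gradient. The SimpleNamespace
+# args, shapes and formats below are assumptions for illustration.
+def _example_coat_quantize_pair():
+    from types import SimpleNamespace
+
+    args = SimpleNamespace(group_size=16, fabit="E4M3", babit="E5M2")
+    bgn, end = Coat_quantize_bgn(args).train(), Coat_quantize_end(args).train()
+
+    x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    x_fp, qx, s_x = bgn(x)    # full-precision plus group-quantized views
+    # ... FP8 blocks would consume (x_fp, qx, s_x) here ...
+    out = end(x_fp, qx, s_x)  # plain tensor again; backward re-quantizes the grad
+    return out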
diff --git a/llava/model/coat/activation/real_quantization/func_rmsnorm.py b/llava/model/coat/activation/real_quantization/func_rmsnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a8d38e2909d7a45a7792e9b1502af5aa7dee95
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/func_rmsnorm.py
@@ -0,0 +1,345 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Layer Normalization
+====================
+In this tutorial, you will write a high-performance layer normalization
+kernel that runs faster than the PyTorch implementation.
+
+In doing so, you will learn about:
+
+* Implementing backward pass in Triton.
+
+* Implementing parallel reduction in Triton.
+
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES
+
+"""RMSNorm Forward + Backward"""
+"""Forward: Input uses 1 * 16 group quantization"""
+"""Forward: Output uses per-tensor quantization"""
+"""Backward: Input uses full-precision/BF16."""
+"""Backward: Output uses full-precision/BF16."""
+
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _rms_norm_fwd_fused(
+ X, # pointer to the input
+ SX, # pointer to the scale of input
+ Y, # pointer to the output
+ SY, # pointer to the scale of output
+ W, # Weight
+ Rstd, # pointer to the 1/std
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X,
+ N2: tl.constexpr, # number of columns in X,
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ eps, # epsilon to avoid division by zero
+ fp8_max,
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ SX += row * scale_stride
+
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+
+ # Dequantize the group-quantized input
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, N2)
+
+ # Compute variance
+ _var = x * x
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+
+ # Write mean / rstd
+ tl.store(Rstd + row, rstd)
+
+ # Normalize and apply linear transformation
+ w = tl.load(W + cols, mask=cols < N, other=0.0)
+ x_hat = x * rstd
+ y = x_hat * w
+
+ # Scale calculation
+ abs_y = tl.abs(y)
+ max_val = tl.max(abs_y) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ tl.store(SY + row, scale_output)
+
+ # Write output
+ tl.store(Y + cols, y, mask=cols < N)
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _rms_norm_bwd_dx_fused(
+ DX, # pointer to the input gradient
+ DY, # pointer to the output gradient
+ DW,
+ X, # pointer to the input
+ SX, # pointer to the input
+ W, # weight
+ Rstd, # pointer to the 1/std
+ Lock,
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X,
+ N2: tl.constexpr, # number of columns in X,
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+):
+ # Map the program id to the elements of X, DX, and DY it should compute.
+ row = tl.program_id(0)
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ mask = cols < N
+ X += row * stride
+ DY += row * stride
+ DX += row * stride
+ SX += row * scale_stride
+
+ # Offset locks and weights/biases gradient pointer for parallel reduction
+ lock_id = row % GROUP_SIZE_M
+ Lock += lock_id
+ Count = Lock + GROUP_SIZE_M
+ DW = DW + lock_id * N + cols
+
+ # Load data to SRAM
+ x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, N2)
+ dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
+
+ # Load weight
+ w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
+ rstd = tl.load(Rstd + row).to(tl.float32)
+
+ # Compute dx
+ xhat = x * rstd
+ wdy = w * dy
+ xhat = tl.where(mask, xhat, 0.0)
+ wdy = tl.where(mask, wdy, 0.0)
+ c1 = tl.sum(xhat * wdy, axis=0) / N
+ dx = (wdy - (xhat * c1)) * rstd # layer norm have c2 term, rmsnorm do not
+
+ dx = dx.to(DX.type.element_ty)
+
+ # Write dx
+ tl.store(DX + cols, dx, mask=mask)
+
+ # Accumulate partial sums for dw/db
+ partial_dw = (dy * xhat).to(w.dtype)
+ while tl.atomic_cas(Lock, 0, 1) == 1:
+ pass
+ count = tl.load(Count)
+ # First store doesn't accumulate
+ if count == 0:
+ tl.atomic_xchg(Count, 1)
+ else:
+ partial_dw += tl.load(DW, mask=mask)
+ tl.store(DW, partial_dw, mask=mask)
+ # Release the lock
+ tl.atomic_xchg(Lock, 0)
+
+
+@triton.jit
+def _rms_norm_bwd_dwdb(
+ DW, # pointer to the partial sum of weights gradient
+ FINAL_DW, # pointer to the weights gradient
+ M, # GROUP_SIZE_M
+ N, # number of columns
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+):
+ # Map the program id to the elements of DW and DB it should compute.
+ pid = tl.program_id(0)
+ cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ # Iterate through the rows of DW and DB to sum the partial sums.
+ for i in range(0, M, BLOCK_SIZE_M):
+ rows = i + tl.arange(0, BLOCK_SIZE_M)
+ mask = (rows[:, None] < M) & (cols[None, :] < N)
+ offs = rows[:, None] * N + cols[None, :]
+ dw += tl.load(DW + offs, mask=mask, other=0.0)
+ # Write the final sum to the output.
+ sum_dw = tl.sum(dw, axis=0)
+ tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
+
+
+def fp8_rmsnorm_forward(x, s_x, w, QB, eps, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # allocate output
+ M, N = x.shape
+ _, SN = s_x.shape
+ y = torch.empty_like(x, dtype=torch.bfloat16)
+ s_y = torch.empty((M,), dtype=torch.bfloat16, device=x.device)
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+ # heuristics for number of warps
+ num_warps = 8
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype]
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = N2 // QB
+
+ # import os
+ # if int(os.environ.get("LOCAL_RANK")) == 7:
+ # print(x.device, x.shape, x.dtype, s_x.shape, s_x.dtype, y.shape, y.dtype, s_y.shape, s_y.dtype, w.shape, w.dtype, rstd.shape, rstd.dtype, x.stride(0), s_x.stride(0), N, N2, SN, SN2, QB, "\n")
+ # enqueue kernel
+ _rms_norm_fwd_fused[(M,)]( #
+ x,
+ s_x,
+ y,
+ s_y,
+ w,
+ rstd, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ eps,
+ fp8MaxValue,
+ SCALE_MIN_THRES=SCALE_MIN_THRES, #
+ num_warps=num_warps,
+ num_ctas=1,
+ )
+
+ # reduction
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, y.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, y.shape[-1])
+
+ return qy, s_y_max, qy_t, (w.clone(), rstd, num_warps)
+
+
+def fp8_rmsnorm_backward(x, s_x, g, w, v, QB, num_warps):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ # enqueue kernel using forward pass heuristics
+ # also compute partial sums for DW and DB
+ M, N = g.shape
+ _, SN = s_x.shape
+
+ GROUP_SIZE_M = 128
+ # heuristics for amount of parallel reduction stream for DW/DB
+ locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
+ _dw = torch.zeros((GROUP_SIZE_M, N), dtype=w.dtype, device=w.device)
+ dw = torch.empty((N,), dtype=w.dtype, device=w.device)
+
+ dx = torch.empty_like(g, dtype=torch.bfloat16)
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = triton.next_power_of_2(SN)
+ _rms_norm_bwd_dx_fused[(M,)]( #
+ dx,
+ g,
+ _dw,
+ x,
+ s_x,
+ w,
+ v,
+ locks, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ num_warps=num_warps,
+ GROUP_SIZE_M=GROUP_SIZE_M,
+ )
+
+ if batched:
+ dx = dx.reshape(BS, -1, dx.shape[-1])
+
+ grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
+ # accumulate partial sums in separate kernel
+ _rms_norm_bwd_dwdb[grid](_dw, dw, min(GROUP_SIZE_M, M), N, BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_ctas=1) # #
+
+ return dx, dw
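+
+
+# --- Reference sketch (editorial addition; not used by the kernels above) ---
+# Plain PyTorch version of the RMSNorm backward math implemented by
+# _rms_norm_bwd_dx_fused / _rms_norm_bwd_dwdb, handy for sanity checks. Inputs are
+# assumed to be dequantized 2D tensors; `rstd` is the per-row 1/RMS saved by the forward.
+def _rmsnorm_backward_reference(x: torch.Tensor, dy: torch.Tensor, w: torch.Tensor, rstd: torch.Tensor):
+    xhat = x * rstd[:, None]                     # normalized input, (M, N)
+    wdy = w[None, :] * dy                        # gradient scaled by the weight
+    c1 = (xhat * wdy).mean(dim=1, keepdim=True)  # the only correction term for RMSNorm
+    dx = (wdy - xhat * c1) * rstd[:, None]
+    dw = (dy * xhat).sum(dim=0)                  # column-wise reduction, as in _rms_norm_bwd_dwdb
+    return dx, dw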
diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd.py b/llava/model/coat/activation/real_quantization/gelu_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8772707f7f5e6ef9b68b9c29f2b1b329f2c3177d
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/gelu_bwd.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""GELU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses full-precision/BF16"""
+"""Output uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_gelu_backward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of GELU's backward
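+    # d/dx GELU(x) = Phi(x) + x * phi(x), with Phi the standard normal CDF and phi its PDF;
+    # `cdf` below is Phi(x) and `exp` is x * phi(x).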
+ pi = float(torch.pi)
+ cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2
+ exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi)
+ gelu_output = cdf + exp
+
+ gelu_output = gelu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(gelu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # gelu_output = tl.fdiv(gelu_output, scale_output)
+ gelu_output = gelu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_gelu_backward(x, s_x, g, QB, fp8type):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+    _, SN = s_x.shape  # assume the quantization block shape is always 1 * G
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_gelu_backward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Per-tensor quantization
+ s_y_max = s_y.max()
+ qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return qy, s_y_max
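+
+
+# --- Reference sketch (editorial addition; not used above) ---
+# The same backward math in plain PyTorch, for checking the kernel: dequantize the
+# 1 x QB group-quantized activation, apply the exact GELU derivative, and multiply by
+# the incoming BF16 gradient. `x_fp8`, `s_x`, and `QB` follow the layout assumed by
+# fp8_gelu_backward (2D input, scales of shape (M, N // QB)).
+def _gelu_backward_reference(x_fp8: torch.Tensor, s_x: torch.Tensor, g: torch.Tensor, QB: int) -> torch.Tensor:
+    M, N = x_fp8.shape
+    x = x_fp8.float().reshape(M, N // QB, QB) * s_x.float().reshape(M, N // QB, 1)
+    cdf = 0.5 * (1.0 + torch.erf(x / (2.0 ** 0.5)))
+    pdf = torch.exp(-x * x / 2.0) / ((2.0 * torch.pi) ** 0.5)
+    dgelu = cdf + x * pdf
+    return (dgelu * g.float().reshape(M, N // QB, QB)).reshape(M, N)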
diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2b36bcb8c51d33392cd740c3dcbf6de4c9f082
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""GELU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses 1 * 16 group quantization"""
+"""Output uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_gelu_backward_legacy_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr,
+ grad_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ s_grad_stride_0,
+ s_grad_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # grad ptr
+ scale_grad_ptr = tl.make_block_ptr(
+ base=grad_scale_ptr,
+ shape=(M, SN),
+ strides=(s_grad_stride_0, s_grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ scale_grad = tl.load(scale_grad_ptr)
+
+ grad = grad.to(tl.float32)
+ scale_grad = scale_grad.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1))
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+ grad = grad * scale_grad
+
+ # Actual Calculation of GELU's backward
+ pi = float(torch.pi)
+ cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2
+ exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi)
+ gelu_output = cdf + exp
+
+ gelu_output = gelu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(gelu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # gelu_output = tl.fdiv(gelu_output, scale_output)
+ gelu_output = gelu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_gelu_backward_legacy(x, s_x, g, s_g, QB):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+ s_g = s_g.reshape(-1, s_g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+    _, SN = s_x.shape  # assume the quantization block shape is always 1 * G
+
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_g, dtype=s_g.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_gelu_backward_legacy_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ s_g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ s_g.stride(0),
+ s_g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Per-tensor quantization
+ s_y_max = s_y.max()
+ qy, s_y_max = fp8_division(y, QB, g.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return qy, s_y_max
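+
+
+# --- Reference sketch (editorial addition; not used above) ---
+# The legacy kernel dequantizes both the activation and the gradient from the
+# 1 x QB group layout before the backward math. The dequantization rule shared by
+# these kernels is a per-group rescale:
+def _group_dequantize_reference(q: torch.Tensor, s: torch.Tensor, QB: int) -> torch.Tensor:
+    M, N = q.shape
+    return (q.float().reshape(M, N // QB, QB) * s.float().reshape(M, N // QB, 1)).reshape(M, N)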
diff --git a/llava/model/coat/activation/real_quantization/gelu_fwd.py b/llava/model/coat/activation/real_quantization/gelu_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..547c6f8339feef79870ade12a407c852631ae136
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/gelu_fwd.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""GELU Activation Forward"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+__all__ = ["fp8_gelu_forward"]
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_gelu_forward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # Actual Calculation of GeLU
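+    # Exact (erf-based) GELU: GELU(x) = x * Phi(x) = x * (1 + erf(x / sqrt(2))) / 2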
+ cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2
+ gelu_output = cdf * input
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(gelu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # gelu_output = tl.fdiv(gelu_output, scale_output)
+ gelu_output = gelu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_gelu_forward(x, s_x, QB, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+    _, SN = s_x.shape  # assume the quantization block shape is always 1 * G
+
+ y = torch.empty_like(x, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x, dtype=s_x.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_gelu_forward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy.shape[-1])
+
+ return qy, s_y_max, qy_t
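+
+
+# --- Reference sketch (editorial addition; not used above) ---
+# The per-group scale rule shared by these kernels: each 1 x QB group is scaled by
+# (max(|y|) + SCALE_MIN_THRES) / fp8_max. The kernel above only records these scales;
+# the actual cast to FP8 happens afterwards in fp8_division_transpose, using the
+# maximum of the group scales as a per-tensor scale.
+def _group_scale_reference(y: torch.Tensor, QB: int, fp8_max: float) -> torch.Tensor:
+    M, N = y.shape
+    groups = y.float().reshape(M, N // QB, QB)
+    return (groups.abs().amax(dim=2) + SCALE_MIN_THRES) / fp8_max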
diff --git a/llava/model/coat/activation/real_quantization/linear.py b/llava/model/coat/activation/real_quantization/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..d643f18f51432efc3a2f80c55b73a18e99f0eb75
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/linear.py
@@ -0,0 +1,307 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+try:
+ from ._division import _stochastic_rounding
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8
+except ImportError:
+ from common import SCALE_MIN_THRES, FP8_MAX_VALUE, convert_str_to_fp8, convert_fp8_to_embit
+ from COAT.coat.activation.real_quantization._division import _stochastic_rounding
+
+import os
+import time
+
+"""Linear Layer Forward + Backward"""
+"""Input uses per-tensor quantization"""
+"""Output is full-precision/BF16 (for FlashAttention) or 1 * 16 quantization (for the rest)"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+def get_configs_io_block():
+ configs = []
+ for nstages in [3]:
+ for block_m in [128, 256]:
+ for block_n in [128, 256]:
+ for block_k in [128, 256]:
+ for nwarps in [8]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+# @triton.autotune(
+# configs=get_configs_io_block(),
+# key=["M", "N", "K"],
+# )
+@triton.jit
+def _fp8matmul_kernel(
+ A,
+ B,
+ C,
+ noise_ptr, # noise for stochastic
+ M: tl.constexpr,
+ N: tl.constexpr,
+ K: tl.constexpr, #
+ stride_am,
+ stride_ak, #
+ stride_bk,
+ stride_bn, #
+ stride_cm,
+ stride_cn, ##
+ Scale_A,
+ Scale_B,
+ Scale_C,
+ stride_scm,
+ stride_scn,
+ output_quantize: tl.constexpr,
+ QB: tl.constexpr, # default to use 1 * 16 quantization
+ BIAS,
+ fp8_max: tl.constexpr,
+ e_bit: tl.constexpr,
+ m_bit: tl.constexpr,
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ grid_m = tl.cdiv(M, BLOCK_M)
+ grid_n = tl.cdiv(N, BLOCK_N)
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
+ k_remaining = K - k * BLOCK_K
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
+
+ acc = tl.dot(a, b, acc)
+
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+
+ scale_a = tl.load(Scale_A)
+ scale_b = tl.load(Scale_B)
+ scale_ab = scale_a.to(tl.float32) * scale_b.to(tl.float32)
+ # fp8 dequantize
+ acc = acc * scale_ab
+
+ if BIAS:
+ bias = tl.load(BIAS + rbn)
+ acc = acc + bias
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+
+ if output_quantize:
+ acc = tl.reshape(acc, (BLOCK_M, BLOCK_N // QB, QB))
+ abs_acc = tl.abs(acc)
+ acc_max = tl.max(abs_acc, axis=2) + SCALE_MIN_THRES
+ # tl.device_print("acc_max", acc_max)
+ acc_scale = acc_max / fp8_max
+ # tl.device_print("acc_scale", acc_scale)
+ acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB, 1))
+ acc = tl.fdiv(acc, acc_scale)
+ acc = tl.reshape(acc, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ noise_block_ptr = noise_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+            noise = tl.load(noise_block_ptr, mask=mask, other=0.0)  # plain pointer arithmetic, so mask instead of boundary_check
+ acc = _stochastic_rounding(acc, noise, e_bit, m_bit)
+
+ acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB))
+ acc_scale = acc_scale.to(Scale_C.type.element_ty)
+ acc = acc.to(C.dtype.element_ty)
+
+ rsm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rsn = pid_n * BLOCK_N // QB + tl.arange(0, BLOCK_N // QB)
+ Scale_C = Scale_C + (rsm[:, None] * stride_scm + rsn[None, :] * stride_scn)
+
+ tl.store(C, acc, mask=mask)
+ tl.store(Scale_C, acc_scale)
+
+ else:
+        # no output quantization: write back the dequantized result in BF16
+ acc = acc.to(C.dtype.element_ty)
+ tl.store(C, acc, mask=mask)
+
+
+def fp8matmul(a, b, output_quantize, scale_a, scale_b, QB, bias=None, stochastic=False):
+ # Deal with batched input
+ if len(a.shape) == 3:
+ BS, batched = a.shape[0], True
+ a = a.reshape(-1, a.shape[2])
+ else:
+ batched = False
+
+ # Check constraints.
+ assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+ assert a.is_contiguous(), "Matrix A must be contiguous"
+ M, K = a.shape
+ K, N = b.shape
+ fp8MaxValue = FP8_MAX_VALUE[a.dtype] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[a.dtype]
+
+ # Allocates output.
+ if output_quantize:
+ c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+ # c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+ scale_c = torch.empty((M, N // QB), device=a.device, dtype=torch.bfloat16)
+ else:
+ c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+        scale_c = torch.empty((1, 1), device=a.device, dtype=torch.bfloat16)  # placeholder; unused when output_quantize is False
+
+ if stochastic:
+ noise = torch.empty_like(c, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+ _fp8matmul_kernel[grid](
+ a,
+ b,
+ c,
+ noise, #
+ M,
+ N,
+ K, #
+ a.stride(0),
+ a.stride(1), #
+ b.stride(0),
+ b.stride(1), #
+ c.stride(0),
+ c.stride(1), #
+ scale_a,
+ scale_b,
+ scale_c,
+ scale_c.stride(0),
+ scale_c.stride(1),
+ output_quantize=output_quantize,
+ QB=QB,
+ BIAS=bias,
+ fp8_max=fp8MaxValue,
+ e_bit=e_bit,
+ m_bit=m_bit,
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ BLOCK_M=128,
+ BLOCK_N=256,
+ BLOCK_K=128,
+ GROUP_M=8,
+ num_stages=3,
+ num_warps=8,
+ )
+ # Reshape output to batch
+ if batched:
+ c = c.reshape(BS, -1, N)
+ if output_quantize:
+ scale_c = scale_c.reshape(BS, -1, N // QB)
+ return c, scale_c
+ else:
+ if output_quantize:
+ scale_c = scale_c.reshape(M, N // QB)
+ return c, scale_c
+ return c
+
+
+def fp8_linear_forward(x, s, w, s_w, output_quantize, QB, bias=None):
+ assert s.numel() == 1, f"X uses per-tensor quantization in linear forward, but the scale shape is {s.shape}"
+ assert s_w.numel() == 1, f"W uses per-tensor quantization in linear forward, but the scale shape is {s_w.shape}"
+
+ w_t = w.t()
+ return fp8matmul(x, w_t, output_quantize, s, s_w, QB, bias)
+
+
+# def fp8_linear_forward(x, s, w, s_w, output_quantize, QB):
+# print("you are using the wrong linear function. ")
+# w_t = w.t()
+# if output_quantize:
+# return fp8matmul(x, w_t, True, s, s_w, QB)
+# else:
+# y = fp8matmul(x, w_t, False, s, s_w, QB)
+
+# return y
+
+
+def fp8_linear_backward(
+ x_t, s, g, s_g, g_t, w_t, s_w, QB, bias=None, stochastic=False, dgrad_quantize=False
+): # dgrad_quantize=True for backward before flashattention
+ assert s.numel() == 1, f"X uses per-tensor quantization in linear backward, but the scale shape is {s.shape}"
+    assert s_g.numel() == 1, f"G uses per-tensor quantization in linear backward, but the scale shape is {s_g.shape}"
+ assert s_w.numel() == 1, f"W uses per-tensor quantization in linear backward, but the scale shape is {s_w.shape}"
+
+ batched = False
+    if len(g.shape) == 3:  # the other inputs must be 2D
+ batched = True
+ BS = g.shape[0]
+ g = g.reshape(-1, g.shape[-1])
+
+ w_t_t = w_t.t()
+ x_t_t = x_t.t()
+ if dgrad_quantize:
+ y, s_y = fp8matmul(g, w_t_t, True, s_g, s_w, QB, stochastic=stochastic)
+ else:
+ y = fp8matmul(g, w_t_t, False, s_g, s_w, QB)
+
+ w_g = fp8matmul(g_t, x_t_t, False, s_g, s, QB)
+
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ if dgrad_quantize:
+ if s_y.numel() > 1:
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+ if dgrad_quantize:
+ return y, s_y, w_g
+ else:
+ return y, w_g
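+
+
+# --- Reference sketch (editorial addition; not used above) ---
+# Numerically, fp8matmul is an FP8 x FP8 matmul accumulated in FP32, dequantized with
+# the two per-tensor scales, plus an optional bias. A plain PyTorch equivalent of the
+# non-quantized output path:
+def _fp8matmul_reference(a_fp8, b_fp8, s_a, s_b, bias=None):
+    c = (a_fp8.float() @ b_fp8.float()) * (s_a.float() * s_b.float())
+    if bias is not None:
+        c = c + bias.float()
+    return c.to(torch.bfloat16)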
diff --git a/llava/model/coat/activation/real_quantization/mul_bwd.py b/llava/model/coat/activation/real_quantization/mul_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..27f8b2839495d6d4f92fa35d537ea75e3e04c0c2
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_bwd.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Element-wise Multiplication Backward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Grad (Down) uses full-precision/BF16"""
+"""Output1 (Gate) uses full-precision/BF16"""
+"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_mul_backward_kernel(
+ output1_ptr, # output
+ output2_ptr,
+ output2_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of Mul Backward
+ grad1 = grad * input2
+ grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N))
+ grad1 = grad1.to(output1_ptr.type.element_ty)
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ tl.store(output1_block_ptr, grad1, boundary_check=(0, 1))
+
+ # Actual Calculation of Mul Backward
+ grad2 = grad * input1
+    # Quantize the grad 2 - Scale calculation
+ abs_grad2 = tl.abs(grad2)
+ max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES
+ scale_grad2 = max_val / fp8_max
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+ # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow
+ grad2 = grad2.to(output2_ptr.type.element_ty)
+ scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty)
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN))
+ grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, grad2, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1))
+
+
+def fp8_mul_backward(
+ x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x1.shape
+    _, SN = s_x1.shape  # assume the quantization block shape is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y1 = torch.empty_like(g, dtype=torch.bfloat16)
+ y2 = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_mul_backward_kernel[grid](
+ y1,
+ y2,
+ s_y2,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ if not output_quantized_transpose:
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (y2, s_y2)
+ else:
+ # Per-tensor quantization
+ s_y2_max = s_y2.max()
+ qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ qy2 = qy2.reshape(BS, -1, qy2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (qy2, s_y2_max, qy2_t)
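+
+
+# --- Reference sketch (editorial addition; not used above) ---
+# The element-wise product y = gate * up has the usual product-rule backward:
+# d_gate = g * up and d_up = g * gate, computed above on dequantized values.
+def _mul_backward_reference(gate: torch.Tensor, up: torch.Tensor, g: torch.Tensor):
+    g = g.float()
+    return g * up.float(), g * gate.float()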
diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..803789e83cb27330d6195a515805440af45cc528
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py
@@ -0,0 +1,374 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import _stochastic_rounding
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, get_configs_io_block
+
+"""Element-wise Multiplication Backward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Grad (Down) uses 1 * 16 group quantization"""
+"""Output1 (Gate) uses 1 * 16 quantization"""
+"""Output2 (Up) uses per-tensor quantization, but should be quantized outside this function""" # Although it is per-tensor quantization, we only apply per-group quantization here, and the reduction should be performed outside this function.
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_mul_backward_legacy_kernel(
+ output1_ptr,
+ output1_scale_ptr, # output
+ output2_ptr,
+ output2_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ grad_ptr,
+ grad_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ s_grad_stride_0,
+ s_grad_stride_1, # scale of input stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ s_output1_stride_0,
+ s_output1_stride_1, # scale of output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # grad ptr
+ scale_grad_ptr = tl.make_block_ptr(
+ base=grad_scale_ptr,
+ shape=(M, SN),
+ strides=(s_grad_stride_0, s_grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ scale_grad = tl.load(scale_grad_ptr)
+
+ grad = grad.to(tl.float32)
+ scale_grad = scale_grad.to(tl.float32)
+
+    # Dequantize the gradient
+ scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1))
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+ grad = grad * scale_grad
+
+ # Actual Calculation of Mul Backward
+ grad1 = grad * input2
+ # Quantize the grad 1 - Scale calculation
+ abs_grad1 = tl.abs(grad1)
+ max_val = tl.max(abs_grad1, axis=2) + SCALE_MIN_THRES
+ scale_grad1 = max_val / fp8_max
+ scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+    grad1 = tl.fdiv(grad1, scale_grad1)  # quantize grad1 into the 1 * 16 group format
+ scale_grad1 = scale_grad1.to(output1_scale_ptr.type.element_ty)
+ scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN))
+ grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ # noise_block_ptr = tl.make_block_ptr(
+ # base=noise_ptr,
+ # shape=(M, N),
+ # strides=(input1_stride_0, input1_stride_1),
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ # block_shape=(BLOCK_M, BLOCK_N),
+ # order=(1, 0)
+ # )
+ # noise = tl.load(noise_block_ptr)
+
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
+ noise_offset = offs_m[:, None] * input1_stride_0 + offs_n[None, :] * input1_stride_1
+ noise = tl.rand(0, noise_offset)
+
+ grad1 = _stochastic_rounding(grad1, noise, e_bit, m_bit)
+
+ grad1 = grad1.to(output1_ptr.type.element_ty)
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output1_ptr = tl.make_block_ptr(
+ base=output1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output1_stride_0, s_output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output1_block_ptr, grad1, boundary_check=(0, 1))
+ tl.store(scale_output1_ptr, scale_grad1, boundary_check=(0, 1))
+
+ # Actual Calculation of Mul Backward
+ grad2 = grad * input1
+    # Quantize the grad 2 - Scale calculation
+ abs_grad2 = tl.abs(grad2)
+ max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES
+ scale_grad2 = max_val / fp8_max
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+ # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow
+ grad2 = grad2.to(output2_ptr.type.element_ty)
+ scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty)
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN))
+ grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, grad2, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1))
+
+
+def fp8_mul_backward_legacy(
+ x1, s_x1, x2, s_x2, g, s_g, QB, stochastic=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+ s_g = s_g.reshape(-1, s_g.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x1.shape
+    _, SN = s_x1.shape  # assume the quantization block shape is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ y1 = torch.empty_like(g, dtype=g.dtype)
+ s_y1 = torch.empty_like(s_g, dtype=s_g.dtype)
+ y2 = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y2 = torch.empty_like(s_g, dtype=s_g.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[g.dtype]
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_mul_backward_legacy_kernel[grid](
+ y1,
+ s_y1,
+ y2,
+ s_y2,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ g,
+ s_g,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ g.stride(0),
+ g.stride(1),
+ s_g.stride(0),
+ s_g.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ s_y1.stride(0),
+ s_y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+ s_y1 = s_y1.reshape(BS, -1, s_y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, s_y1, y2, s_y2
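+
+
+# --- Reference sketch (editorial addition; not used above) ---
+# Illustration of stochastic rounding on a uniform grid (the kernels instead use the
+# _stochastic_rounding helper, which rounds on the non-uniform FP8 grid): a value is
+# rounded up with probability equal to its fractional part, so the rounding is
+# unbiased in expectation. The uniform(-0.5, 0.5) noise mirrors the noise tensor built above.
+def _stochastic_round_to_grid_reference(x: torch.Tensor, step: float) -> torch.Tensor:
+    scaled = x / step
+    noise = torch.rand_like(scaled) - 0.5
+    return torch.round(scaled + noise) * step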
diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..d23c26f6c86284b8d614ae0f90819bc2e4bc1718
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py
@@ -0,0 +1,337 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Element-wise Multiplication Backward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Grad (Down) uses full-precision/BF16"""
+"""Output1 (Gate) uses full-precision/BF16"""
+"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_mul_backward_silu_forward_kernel(
+ output1_ptr, # output
+ output2_ptr,
+ output2_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- recompute SiLU ---
+ # Actual Calculation of SiLU
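+    # SiLU(x) = x * sigmoid(x); the result is fake-quantized below (divide by the group
+    # scale, cast to FP8, rescale) so it matches the activation produced by the forward pass.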
+ sigmoid = 1 / (1.0 + libdevice.exp(-input1))
+ silu_output = sigmoid * input1
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(input1_ptr.type.element_ty)
+
+ input1 = silu_output * scale_output
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of Mul Backward
+ grad1 = grad * input2
+ grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N))
+ grad1 = grad1.to(output1_ptr.type.element_ty)
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ tl.store(output1_block_ptr, grad1, boundary_check=(0, 1))
+
+ # Actual Calculation of Mul Backward
+ grad2 = grad * input1
+ # Quantize grad 2 - scale calculation
+ abs_grad2 = tl.abs(grad2)
+ max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES
+ scale_grad2 = max_val / fp8_max
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+ # grad2 = tl.fdiv(grad2, scale_grad2) # do not quantize here; per-tensor quantization is applied later in the data flow
+ grad2 = grad2.to(output2_ptr.type.element_ty)
+ scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty)
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN))
+ grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, grad2, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1))
+
+
+def fp8_mul_backward_silu_forward(
+ x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
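+ # NOTE: `noise` is allocated for stochastic rounding but is not passed to the Triton
+ # kernel below; as the signature comment says, stochastic rounding is handled outside
+ # this function.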
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y1 = torch.empty_like(g, dtype=torch.bfloat16)
+ y2 = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_mul_backward_silu_forward_kernel[grid](
+ y1,
+ y2,
+ s_y2,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ if not output_quantized_transpose:
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (y2, s_y2)
+ else:
+ # Per-tensor quantization
+ s_y2_max = s_y2.max()
+ qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ qy2 = qy2.reshape(BS, -1, qy2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (qy2, s_y2_max, qy2_t)
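+
+
+# The pure-PyTorch sketch below mirrors what the fused kernel above computes for a
+# SwiGLU-style MLP whose forward pass is y = SiLU(x1) * x2. It is illustrative only
+# and is not called anywhere in this file; the bf16 output dtype and the use of
+# SCALE_MIN_THRES follow the wrapper above.
+def _fp8_mul_backward_silu_forward_reference(x1q, s_x1, x2q, s_x2, g, QB, fp8_max):
+    M, N = g.shape
+    # Dequantize the 1 x QB group-quantized gate (x1) and up (x2) activations
+    x1 = x1q.float().reshape(M, N // QB, QB) * s_x1.float().unsqueeze(-1)
+    x2 = x2q.float().reshape(M, N // QB, QB) * s_x2.float().unsqueeze(-1)
+    # Recompute SiLU(x1) and fake-quantize it with fresh per-group scales
+    silu_x1 = torch.nn.functional.silu(x1)
+    s1 = (silu_x1.abs().amax(dim=-1, keepdim=True) + SCALE_MIN_THRES) / fp8_max
+    silu_fq = (silu_x1 / s1).to(x1q.dtype).float() * s1
+    # Mul backward: y1 = g * x2 feeds the SiLU backward; y2 = g * SiLU(x1) is the grad w.r.t. x2
+    g3 = g.float().reshape(M, N // QB, QB)
+    y1 = (g3 * x2).reshape(M, N).to(torch.bfloat16)
+    grad2 = g3 * silu_fq
+    s_y2 = (grad2.abs().amax(dim=-1) + SCALE_MIN_THRES) / fp8_max
+    y2 = grad2.reshape(M, N).to(torch.bfloat16)  # left unscaled; per-tensor quantization happens later
+    return y1, (y2, s_y2.to(torch.bfloat16))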
diff --git a/llava/model/coat/activation/real_quantization/mul_fwd.py b/llava/model/coat/activation/real_quantization/mul_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..0037c7220c6f1bb0b7b9b5946a7d2553b44c7426
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_fwd.py
@@ -0,0 +1,260 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""Element-wise Multiplication Forward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Output uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+fp8_max_value = {
+ torch.float8_e4m3fn: 448,
+ torch.float8_e5m2: 57344,
+}
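+
+# Shape convention used below (illustrative numbers): with QB = 16 and an activation of
+# shape (M, N) = (4, 4096), the per-group scale tensor has shape (M, SN) = (4, 256),
+# i.e. one scale for every 1 x 16 group along the last dimension.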
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def fp8_mul_forward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+ # Actual calculation of the element-wise multiplication
+ mul_output = input1 * input2
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(mul_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # mul_output = tl.fdiv(mul_output, scale_output) # do not quantize the output since it should use per-tensor quantization afterwards
+ mul_output = mul_output.to(output_ptr.type.element_ty)
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ mul_output = tl.reshape(mul_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # mul_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, mul_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_mul_forward(x1, s_x1, x2, s_x2, QB, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ y = torch.empty_like(x1, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x1, dtype=s_x1.dtype)
+ fp8MaxValue = fp8_max_value[x1.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ fp8_mul_forward_kernel[grid](
+ y,
+ s_y,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x2.dtype, s_y_max)
+ qy = qy.to(x2.dtype)
+ qy_t = qy_t.to(x2.dtype)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy.shape[-1])
+
+ return qy, s_y_max, qy_t
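+
+
+# Pure-PyTorch sketch of the group-level math performed by the kernel above
+# (illustrative only; as in fp8_mul_forward, the final per-tensor FP8 quantization
+# is still delegated to fp8_division_transpose).
+def _fp8_mul_forward_reference(x1q, s_x1, x2q, s_x2, QB, fp8_max):
+    M, N = x1q.shape
+    x1 = x1q.float().reshape(M, N // QB, QB) * s_x1.float().unsqueeze(-1)  # dequantize gate
+    x2 = x2q.float().reshape(M, N // QB, QB) * s_x2.float().unsqueeze(-1)  # dequantize up
+    y = x1 * x2
+    s_y = (y.abs().amax(dim=-1) + SCALE_MIN_THRES) / fp8_max  # per 1 x QB group scale
+    return y.reshape(M, N).to(torch.bfloat16), s_y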
diff --git a/llava/model/coat/activation/real_quantization/silu_bwd.py b/llava/model/coat/activation/real_quantization/silu_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..052bf66e6af6f5f3ce40e08e7b5c8c92e97bfbe3
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/silu_bwd.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""SiLU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses full-precision / BF16"""
+"""Output uses per-tensor quantization, we can choose whether it should be quantized inside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_silu_backward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of SiLU's backward
+ sigmoid = 1 / (1.0 + libdevice.exp(-input))
+ silu_output = sigmoid + input * sigmoid * (1 - sigmoid)
+ silu_output = silu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # silu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, silu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_silu_backward(
+ x, s_x, g, QB, fp8type, stochastic=False, output_quantized_transpose=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_silu_backward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ if not output_quantized_transpose:
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
+ else:
+ # Per-tensor quantization
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, fp8type, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return qy, s_y_max, qy_t
+
+ # # Per-tensor quantization
+ # s_y_max = s_y.max()
+ # qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max)
+
+ # # Recover 2D to 3D
+ # if batched:
+ # y = y.reshape(BS, -1, y.shape[-1])
+ # qy = qy.reshape(BS, -1, qy.shape[-1])
+ # s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ # return qy, s_y_max
diff --git a/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..88de423ad973f34718f8d44e1d2e37b3113e1737
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py
@@ -0,0 +1,248 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""SiLU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses 1 * 16 group quantization"""
+"""Output uses per-tensor quantization, but should be quantized outside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_silu_backward_legacy_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr,
+ grad_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ s_grad_stride_0,
+ s_grad_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # grad ptr
+ scale_grad_ptr = tl.make_block_ptr(
+ base=grad_scale_ptr,
+ shape=(M, SN),
+ strides=(s_grad_stride_0, s_grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ scale_grad = tl.load(scale_grad_ptr)
+
+ grad = grad.to(tl.float32)
+ scale_grad = scale_grad.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1))
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+ grad = grad * scale_grad
+
+ # Actual Calculation of SiLU's backward
+ sigmoid = 1 / (1.0 + libdevice.exp(-input))
+ silu_output = sigmoid + input * sigmoid * (1 - sigmoid)
+ silu_output = silu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # silu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, silu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_silu_backward_legacy(x, s_x, g, s_g, QB, stochastic=False): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+ s_g = s_g.reshape(-1, s_g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_g, dtype=s_g.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_silu_backward_legacy_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ s_g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ s_g.stride(0),
+ s_g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
diff --git a/llava/model/coat/activation/real_quantization/silu_fwd.py b/llava/model/coat/activation/real_quantization/silu_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e6dadc42b2e25ad548ec3e2f30cc146b8ce01d0
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/silu_fwd.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+try:
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+except ImportError:  # allow running this file directly, outside the package
+ from common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""SiLU Activation Forward"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_silu_forward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # Actual Calculation of SiLU
+ sigmoid = 1 / (1.0 + libdevice.exp(-input))
+ silu_output = sigmoid * input
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # silu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, silu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_silu_forward(x, s_x, QB):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ y = torch.empty_like(x, dtype=x.dtype)
+ s_y = torch.empty_like(s_x, dtype=s_x.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_silu_forward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
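+
+
+# Pure-PyTorch sketch of what the kernel above does per 1 x QB group (illustrative only):
+# dequantize, apply SiLU, derive a fresh per-group scale, and re-quantize into the same
+# FP8 dtype as the input; `fp8_max` corresponds to FP8_MAX_VALUE[x.dtype] above.
+def _fp8_silu_forward_reference(xq, s_x, QB, fp8_max):
+    M, N = xq.shape
+    x = xq.float().reshape(M, N // QB, QB) * s_x.float().unsqueeze(-1)  # dequantize
+    y = torch.nn.functional.silu(x)
+    s_y = (y.abs().amax(dim=-1, keepdim=True) + SCALE_MIN_THRES) / fp8_max
+    yq = (y / s_y).reshape(M, N).to(xq.dtype)  # re-quantize to the input FP8 dtype
+    return yq, s_y.squeeze(-1).to(s_x.dtype)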
diff --git a/llava/model/coat/activation/utils.py b/llava/model/coat/activation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14b79e29cac93f6049c6fa81ff537bd2a5cb8fd
--- /dev/null
+++ b/llava/model/coat/activation/utils.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from collections import defaultdict
+
+import torch
+
+
+def quant_get_local_rank() -> int:
+ return int(os.environ.get("LOCAL_RANK") or 0)
+
+
+record_memory_allocated = defaultdict(list)
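+
+# Illustrative usage (assumed, not prescribed by this file): rank-0-only logging and
+# simple per-step memory bookkeeping.
+#
+#   if quant_get_local_rank() == 0:
+#       print("FP8 quantization enabled on local rank 0")
+#   record_memory_allocated["after_forward"].append(torch.cuda.memory_allocated())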
diff --git a/llava/model/coat/fp8_trainer.py b/llava/model/coat/fp8_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9d68029db9fc7fbbd6af01b34d2cf342b4294f
--- /dev/null
+++ b/llava/model/coat/fp8_trainer.py
@@ -0,0 +1,626 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# This file only adds the FP8 logic around lines 411 to 415 (toggling FP8Manager.is_first_microbatch); everything else is unchanged from the original Hugging Face Trainer class.
+
+import math
+import os
+import shutil
+import sys
+import time
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from packaging import version
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+from transformers import Trainer
+# Module-level helpers referenced below; import paths assume the recent transformers
+# version this Trainer copy is based on.
+from transformers.trainer import TRAINER_STATE_NAME, _is_peft_model
+from transformers.utils import XLA_FSDPV2_MIN_VERSION
+from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
+from transformers.integrations import hp_params
+from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from transformers.integrations.tpu import tpu_spmd_dataloader
+from transformers.trainer_callback import ExportableState, TrainerState
+from transformers.trainer_pt_utils import get_model_param_count
+from transformers.trainer_utils import HPSearchBackend, TrainOutput, has_length, speed_metrics
+from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
+from transformers.utils import (
+ is_accelerate_available,
+ is_apex_available,
+ is_bitsandbytes_available,
+ is_datasets_available,
+ is_galore_torch_available,
+ is_grokadamw_available,
+ is_in_notebook,
+ is_ipex_available,
+ is_liger_kernel_available,
+ is_lomo_available,
+ is_peft_available,
+ is_safetensors_available,
+ is_sagemaker_dp_enabled,
+ is_sagemaker_mp_enabled,
+ is_schedulefree_available,
+ is_torch_compile_available,
+ is_torch_mlu_available,
+ is_torch_mps_available,
+ is_torch_musa_available,
+ is_torch_neuroncore_available,
+ is_torch_npu_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ is_torchao_available,
+ logging,
+)
+
+if is_apex_available():
+ from apex import amp
+
+if is_datasets_available():
+ import datasets
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ from torch_xla import __version__ as XLA_VERSION
+
+ IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
+ if IS_XLA_FSDPV2_POST_2_2:
+ import torch_xla.distributed.spmd as xs
+ import torch_xla.runtime as xr
+else:
+ IS_XLA_FSDPV2_POST_2_2 = False
+
+
+if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+ from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+ from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+ IS_SAGEMAKER_MP_POST_1_10 = False
+
+
+if is_safetensors_available():
+ import safetensors.torch
+
+if is_peft_available():
+ from peft import PeftModel
+
+
+if is_accelerate_available():
+ from accelerate import Accelerator
+ from accelerate import __version__ as accelerate_version
+ from accelerate import skip_first_batches
+ from accelerate.utils import (
+ DistributedDataParallelKwargs,
+ DistributedType,
+ GradientAccumulationPlugin,
+ load_fsdp_model,
+ load_fsdp_optimizer,
+ save_fsdp_model,
+ save_fsdp_optimizer,
+ )
+
+ DATA_SAMPLERS = [RandomSampler]
+ if version.parse(accelerate_version) > version.parse("0.23.0"):
+ from accelerate.data_loader import SeedableRandomSampler
+
+ DATA_SAMPLERS += [SeedableRandomSampler]
+
+ if is_deepspeed_available():
+ from accelerate.utils import DeepSpeedSchedulerWrapper
+
+if is_accelerate_available("0.28.0"):
+ from accelerate.utils import DataLoaderConfiguration
+
+from .activation.models._fp8manager import FP8Manager
+
+logger = logging.get_logger(__name__)
+
+
+class CoatFP8Trainer(Trainer):
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ self.accelerator.free_memory()
+ self._train_batch_size = batch_size
+ if self.args.auto_find_batch_size:
+ if self.state.train_batch_size != self._train_batch_size:
+ from accelerate.utils import release_memory
+
+ (self.model_wrapped,) = release_memory(self.model_wrapped)
+ self.model_wrapped = self.model
+
+ # Check for DeepSpeed *after* the initial pass and modify the config
+ if self.is_deepspeed_enabled:
+ # Temporarily unset `self.args.train_batch_size`
+ original_bs = self.args.per_device_train_batch_size
+ self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
+ self.propagate_args_to_deepspeed(True)
+ self.args.per_device_train_batch_size = original_bs
+ self.state.train_batch_size = self._train_batch_size
+ logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+ if self.is_fsdp_xla_v2_enabled:
+ train_dataloader = tpu_spmd_dataloader(train_dataloader)
+
+ # Setting up training control variables:
+ # number of training epochs: num_train_epochs
+ # number of training steps per epoch: num_update_steps_per_epoch
+ # total number of training steps to execute: max_steps
+ total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
+
+ len_dataloader = None
+ num_train_tokens = None
+ if has_length(train_dataloader):
+ len_dataloader = len(train_dataloader)
+ num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ num_examples = self.num_examples(train_dataloader)
+ if args.max_steps > 0:
+ max_steps = args.max_steps
+ num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+ args.max_steps % num_update_steps_per_epoch > 0
+ )
+ # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = args.max_steps * total_train_batch_size
+ if args.include_tokens_per_second:
+ num_train_tokens = (
+ self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+ )
+ else:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+ if args.include_tokens_per_second:
+ num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
+ elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
+ max_steps = args.max_steps
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_examples = total_train_batch_size * args.max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+ if args.include_tokens_per_second:
+ num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+ else:
+ raise ValueError(
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
+ )
+
+ if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+ if self.args.n_gpu > 1:
+ # nn.DataParallel(model) replicates the model, creating new variables and module
+ # references registered here no longer work on other gpus, breaking the module
+ raise ValueError(
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+ " (torchrun or torch.distributed.launch (deprecated))."
+ )
+ else:
+ debug_overflow = DebugUnderflowOverflow(self.model) # noqa
+
+ delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+
+ # We need to reset the scheduler, as its parameters may be different on subsequent calls
+ if self._created_lr_scheduler:
+ self.lr_scheduler = None
+ self._created_lr_scheduler = False
+
+ if self.is_deepspeed_enabled:
+ self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
+
+ if not delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ self.state = TrainerState(
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ]
+ )
+ self.state.is_hyper_param_search = trial is not None
+ self.state.train_batch_size = self._train_batch_size
+
+ # Compute absolute values for logging, eval, and save if given as ratio
+ if args.logging_steps is not None:
+ if args.logging_steps < 1:
+ self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+ else:
+ self.state.logging_steps = args.logging_steps
+ if args.eval_steps is not None:
+ if args.eval_steps < 1:
+ self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+ else:
+ self.state.eval_steps = args.eval_steps
+ if args.save_steps is not None:
+ if args.save_steps < 1:
+ self.state.save_steps = math.ceil(max_steps * args.save_steps)
+ else:
+ self.state.save_steps = args.save_steps
+
+ # Activate gradient checkpointing if needed
+ if args.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
+
+ model = self._wrap_model(self.model_wrapped)
+
+ # as the model is wrapped, don't use `accelerator.prepare`
+ # this is for unhandled cases such as
+ # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+ use_accelerator_prepare = True if model is self.model else False
+
+ if delay_optimizer_creation:
+ if use_accelerator_prepare:
+ self._fsdp_qlora_plugin_updates()
+ self.model = self.accelerator.prepare(self.model)
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ # prepare using `accelerator` prepare
+ if use_accelerator_prepare:
+ self.model.train()
+ if hasattr(self.lr_scheduler, "step"):
+ if self.use_apex:
+ model = self.accelerator.prepare(self.model)
+ else:
+ model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+ else:
+ # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+ self.model, self.optimizer, self.lr_scheduler
+ )
+ elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ # In this case we are in DDP + LOMO, which should be supported
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+
+ if self.is_fsdp_enabled:
+ self.model = self.model_wrapped = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # ckpt loading
+ if resume_from_checkpoint is not None:
+ if self.is_deepspeed_enabled:
+ deepspeed_load_checkpoint(
+ self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+ )
+ elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
+ self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
+
+ # Check if saved optimizer or scheduler states exist
+ self._load_optimizer_and_scheduler(resume_from_checkpoint)
+
+ # important: at this point:
+ # self.model is the Transformers Model
+ # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
+ # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples:,}")
+ logger.info(f" Num Epochs = {num_train_epochs:,}")
+ logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+ if self.args.per_device_train_batch_size != self._train_batch_size:
+ logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps:,}")
+ logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
+
+ self.state.epoch = 0
+ start_time = time.time()
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+ steps_trained_progress_bar = None
+
+ # Check if continuing training from a checkpoint
+ if resume_from_checkpoint is not None and os.path.isfile(
+ os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+ ):
+ self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+ self.compare_trainer_and_checkpoint_args(self.args, self.state)
+ self._load_callback_state()
+ epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+ if not args.ignore_data_skip:
+ steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+ steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+ else:
+ steps_trained_in_current_epoch = 0
+
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
+ logger.info(f" Continuing training from epoch {epochs_trained}")
+ logger.info(f" Continuing training from global step {self.state.global_step}")
+ if not args.ignore_data_skip:
+ logger.info(
+ f" Will skip the first {epochs_trained} epochs then the first"
+ f" {steps_trained_in_current_epoch} batches in the first epoch."
+ )
+
+ # Update the references
+ self.callback_handler.model = self.model
+ self.callback_handler.optimizer = self.optimizer
+ self.callback_handler.lr_scheduler = self.lr_scheduler
+ self.callback_handler.train_dataloader = train_dataloader
+ if self.hp_name is not None and self._trial is not None:
+ # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
+ # parameter to Train when using DDP.
+ self.state.trial_name = self.hp_name(self._trial)
+ if trial is not None:
+ assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
+ self.state.trial_params = hp_params(assignments)
+ else:
+ self.state.trial_params = None
+ # This should be the same if the state has been saved but in case the training arguments changed, it's safer
+ # to set this after the load.
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+ tr_loss = torch.tensor(0.0).to(args.device)
+ # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
+ self._total_loss_scalar = 0.0
+ self._globalstep_last_logged = self.state.global_step
+ model.zero_grad()
+ grad_norm: Optional[float] = None
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ if args.eval_on_start:
+ self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+
+ total_batched_samples = 0
+ for epoch in range(epochs_trained, num_train_epochs):
+ epoch_iterator = train_dataloader
+ if hasattr(epoch_iterator, "set_epoch"):
+ epoch_iterator.set_epoch(epoch)
+
+ # Reset the past mems state at the beginning of each epoch if necessary.
+ if args.past_index >= 0:
+ self._past = None
+
+ steps_in_epoch = (
+ len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
+ )
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+
+ rng_to_sync = False
+ steps_skipped = 0
+ if steps_trained_in_current_epoch > 0:
+ epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+ steps_skipped = steps_trained_in_current_epoch
+ steps_trained_in_current_epoch = 0
+ rng_to_sync = True
+
+ step = -1
+ for step, inputs in enumerate(epoch_iterator):
+ # NOTE: FP8 related
+ if total_batched_samples % args.gradient_accumulation_steps == 0:
+ FP8Manager.is_first_microbatch = True
+ else:
+ FP8Manager.is_first_microbatch = False
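+ # The FP8Manager.is_first_microbatch flag marks the first micro-batch of each
+ # gradient-accumulation window so that FP8 modules can refresh any cached
+ # quantization state (interpretation based on the flag's name and usage; see
+ # activation/models/_fp8manager.py).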
+
+ total_batched_samples += 1
+
+ if self.args.include_num_input_tokens_seen:
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
+ if main_input_name not in inputs:
+ logger.warning(
+ "Tried to track the number of tokens seen, however the current model is "
+ "not configured properly to know what item is the input. To fix this, add "
+ "a `main_input_name` attribute to the model class you are using."
+ )
+ else:
+ self.state.num_input_tokens_seen += (
+ torch.sum(
+ self.accelerator.gather(
+ torch.tensor(
+ inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
+ )
+ )
+ )
+ .cpu()
+ .item()
+ )
+ if rng_to_sync:
+ self._load_rng_state(resume_from_checkpoint)
+ rng_to_sync = False
+
+ # Skip past any already trained steps if resuming training
+ if steps_trained_in_current_epoch > 0:
+ steps_trained_in_current_epoch -= 1
+ if steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.update(1)
+ if steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+ continue
+ elif steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.close()
+ steps_trained_progress_bar = None
+
+ if step % args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ with self.accelerator.accumulate(model):
+ tr_loss_step = self.training_step(model, inputs)
+
+ if (
+ args.logging_nan_inf_filter
+ and not is_torch_xla_available()
+ and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+ ):
+ # if loss is nan or inf simply add the average of previous logged losses
+ tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+ else:
+ if tr_loss.device != tr_loss_step.device:
+ raise ValueError(
+ f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
+ )
+ tr_loss += tr_loss_step
+
+ self.current_flos += float(self.floating_point_ops(inputs))
+
+ is_last_step_and_steps_less_than_grad_acc = (
+ steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
+ )
+
+ if (
+ total_batched_samples % args.gradient_accumulation_steps == 0
+ or
+ # last step in epoch but step is always smaller than gradient_accumulation_steps
+ is_last_step_and_steps_less_than_grad_acc
+ ):
+ # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
+ # in accelerate. So, explicitly enable sync gradients to True in that case.
+ if is_last_step_and_steps_less_than_grad_acc:
+ self.accelerator.gradient_state._set_sync_gradients(True)
+
+ # Gradient clipping
+ if args.max_grad_norm is not None and args.max_grad_norm > 0:
+ # deepspeed does its own clipping
+
+ if is_sagemaker_mp_enabled() and args.fp16:
+ _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
+ elif self.use_apex:
+ # Revert to normal clipping otherwise, handling Apex or full precision
+ _grad_norm = nn.utils.clip_grad_norm_(
+ amp.master_params(self.optimizer),
+ args.max_grad_norm,
+ )
+ else:
+ _grad_norm = self.accelerator.clip_grad_norm_(
+ model.parameters(),
+ args.max_grad_norm,
+ )
+
+ if is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+ grad_norm = model.get_global_grad_norm()
+ # In some cases the grad norm may not return a float
+ if hasattr(grad_norm, "item"):
+ grad_norm = grad_norm.item()
+ else:
+ grad_norm = _grad_norm
+
+ self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
+
+ self.optimizer.step()
+
+ self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+
+ optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+ if optimizer_was_run:
+ # Delay optimizer scheduling until metrics are generated
+ if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step()
+
+ model.zero_grad()
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+ self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+ else:
+ self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ # PyTorch/XLA relies on the data loader to insert the mark_step for
+ # each step. Since we are breaking the loop early, we need to manually
+ # insert the mark_step here.
+ if is_torch_xla_available():
+ xm.mark_step()
+ break
+ if step < 0:
+ logger.warning(
+ "There seems not to be a single sample in your epoch_iterator, stopping training at step"
+ f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+ f" num_steps ({max_steps}) higher than the number of available samples."
+ )
+ self.control.should_training_stop = True
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ if is_torch_xla_available():
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+ else:
+ logger.warning(
+ "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+ "configured. Check your training configuration if this is unexpected."
+ )
+ if self.control.should_training_stop:
+ break
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of training
+ delattr(self, "_past")
+
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+ if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+ # Wait for everyone to get here so we are sure the model has been saved by process 0.
+ if is_torch_xla_available():
+ xm.rendezvous("load_best_model_at_end")
+ elif args.parallel_mode == ParallelMode.DISTRIBUTED:
+ dist.barrier()
+ elif is_sagemaker_mp_enabled():
+ smp.barrier()
+
+ self._load_best_model()
+
+ # add remaining tr_loss
+ self._total_loss_scalar += tr_loss.item()
+ effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
+ train_loss = self._total_loss_scalar / effective_global_step
+
+ metrics = speed_metrics(
+ "train",
+ start_time,
+ num_samples=num_train_samples,
+ num_steps=self.state.max_steps,
+ num_tokens=num_train_tokens,
+ )
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
+ metrics["train_loss"] = train_loss
+
+ self.is_in_train = False
+
+ self._memory_tracker.stop_and_update_metrics(metrics)
+
+ self.log(metrics)
+
+ run_dir = self._get_output_dir(trial)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+ # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
+ if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+ for checkpoint in checkpoints_sorted:
+ if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+ # Wait for the checkpoint to be uploaded.
+ self._finish_current_push()
+
+ # After training we make sure to retrieve back the original forward pass method
+ # for the embedding layer by removing the forward post hook.
+ if self.neftune_noise_alpha is not None:
+ self._deactivate_neftune(self.model)
+
+ return TrainOutput(self.state.global_step, train_loss, metrics)
diff --git a/llava/model/coat/optimizer/fp8_adamw.py b/llava/model/coat/optimizer/fp8_adamw.py
new file mode 100644
index 0000000000000000000000000000000000000000..7776ef1539e07773285cf0d220811d6e73070542
--- /dev/null
+++ b/llava/model/coat/optimizer/fp8_adamw.py
@@ -0,0 +1,515 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
+from itertools import chain
+from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Optional, Tuple, Union
+
+import qoptim_cuda
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+from typing_extensions import ParamSpec, Self, TypeAlias
+
+StateDict: TypeAlias = Dict[str, Any]
+
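+# FP8 formats for the optimizer states: E4M3 (max normal value 448) favors precision,
+# E5M2 (max normal value 57344) favors dynamic range.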
+convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
+
+
+class CoatAdamW(Optimizer):
+ def __init__(
+ self,
+ qargs,
+ params,
+ lr: float = 1e-3,
+ betas: Tuple[float, float] = (0.9, 0.999),
+ eps: float = 1e-8,
+ weight_decay: float = 1e-2,
+ amsgrad: bool = False,
+ *,
+ fused: Optional[bool] = None,
+ ):
+ self.qargs = qargs
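+        # The FP8 kernels assume the same expansion setting for the first- and second-order moments.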
+ assert self.qargs.first_order_expansion == self.qargs.second_order_expansion
+ if not 0.0 <= lr:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if not 0.0 <= eps:
+ raise ValueError(f"Invalid epsilon value: {eps}")
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+ if not 0.0 <= weight_decay:
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ amsgrad=amsgrad,
+ fused=fused,
+ )
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault("amsgrad", False)
+ fused = group.setdefault("fused", None)
+ for p in group["params"]:
+ p_state = self.state.get(p, [])
+ if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
+ step_val = float(p_state["step"])
+ p_state["step"] = torch.tensor(step_val, dtype=torch.float32)
+
+ def _init_group(
+ self,
+ group,
+ params_with_grad,
+ grads,
+ amsgrad,
+ use_expansion,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ ):
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError("AdamW does not support sparse gradients")
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # print(f'Param shape: {p.shape}', file=open('debug.txt', 'a'))
+ # print(f'Param shape: {p.shape}, {p.device}')
+
+ # State initialization
+ if len(state) == 0:
+                    # Host `step` on the CPU, since per-step kernel launches are costly on CUDA and XLA.
+ state["step"] = torch.tensor(0.0)
+
+ # Should be torch.float8_e4m3fn
+ first_order_dtype = convert_str_to_fp8[self.qargs.first_order_bit]
+ second_order_dtype = convert_str_to_fp8[self.qargs.second_order_bit]
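+                    # One FP32 scale (plus optional expand / sqrt(min*max) factors) is kept per qgroup_size elements.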
+ scale_shape = (p.numel() + self.qargs.qgroup_size - 1) // self.qargs.qgroup_size
+
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(p, dtype=first_order_dtype, memory_format=torch.preserve_format)
+ state["scale_exp_avg"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
+ if use_expansion:
+ state["expand_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ state["sqrt_minmax_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(p, dtype=second_order_dtype, memory_format=torch.preserve_format)
+ state["scale_exp_avg_sq"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
+ if use_expansion:
+ state["expand_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ state["sqrt_minmax_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state["max_exp_avg_sq"] = torch.zeros(p, memory_format=torch.preserve_format)
+
+ exp_avgs.append(state["exp_avg"])
+ scale_exp_avgs.append(state["scale_exp_avg"])
+ if use_expansion:
+ expand_exp_avgs.append(state["expand_exp_avg"])
+ sqrt_minmax_exp_avgs.append(state["sqrt_minmax_exp_avg"])
+ exp_avg_sqs.append(state["exp_avg_sq"])
+ scale_exp_avg_sqs.append(state["scale_exp_avg_sq"])
+ if use_expansion:
+ expand_exp_avg_sqs.append(state["expand_exp_avg_sq"])
+ sqrt_minmax_exp_avg_sqs.append(state["sqrt_minmax_exp_avg_sq"])
+
+ if group["amsgrad"]:
+ max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+
+ state_steps.append(state["step"])
+
+ @torch._disable_dynamo
+ def load_state_dict(self, state_dict: StateDict) -> None:
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # shallow copy, to be consistent with module API
+ state_dict = state_dict.copy()
+
+ for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
+ hook_result = pre_hook(self, state_dict)
+ if hook_result is not None:
+ state_dict = hook_result
+
+ # Validate the state_dict
+ groups = self.param_groups
+
+ # Deepcopy as we write into saved_groups later to update state
+ saved_groups = deepcopy(state_dict["param_groups"])
+
+ if len(groups) != len(saved_groups):
+ raise ValueError("loaded state dict has a different number of " "parameter groups")
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+            raise ValueError(
+                "loaded state dict contains a parameter group that doesn't match the size of optimizer's group"
+            )
+
+ # Update the state
+ id_map = dict(
+ zip(
+ chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups)
+ )
+ )
+
+ def _cast(param, value, param_id=None, param_groups=None, key=None):
+ r"""Make a deep copy of value, casting all tensors to device of param."""
+ if isinstance(value, torch.Tensor):
+ return CoatAdamW._process_value_according_to_param_policy(param, value, param_id, param_groups, key)
+ elif isinstance(value, dict):
+ return {
+ k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()
+ }
+ elif isinstance(value, Iterable):
+ return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
+ else:
+ return value
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
+ for k, v in state_dict["state"].items():
+ if k in id_map:
+ param = id_map[k]
+ state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
+ else:
+ state[k] = v
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
+ new_group["params"] = group["params"]
+ return new_group
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
+
+ for post_hook in self._optimizer_load_state_dict_post_hooks.values():
+ post_hook(self)
+
+ @staticmethod
+ def _process_value_according_to_param_policy(
+ param: torch.Tensor,
+ value: torch.Tensor,
+ param_id: int,
+ param_groups: List[Dict[Any, Any]],
+ key: Hashable = None,
+ ) -> torch.Tensor:
+ # Floating-point types are a bit special here. They are the only ones
+ # that are assumed to always match the type of params.
+        # Make sure state['step'] is not cast https://github.com/pytorch/pytorch/issues/74424
+ # UNLESS fused or capturable, see note [special device hosting for step]
+ fused = False
+ capturable = False
+ assert param_groups is not None
+ for pg in param_groups:
+ if param_id in pg["params"]:
+ fused = pg["fused"] if "fused" in pg else False
+ capturable = pg["capturable"] if "capturable" in pg else False
+ break
+ if key == "step":
+ if capturable or fused:
+ return value.to(dtype=torch.float32, device=param.device)
+ else:
+ return value
+ else:
+ assert value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2, torch.float32]
+ return value.to(device=param.device) # do not cast optimizer states
+ # if param.is_floating_point():
+ # return value.to(dtype=param.dtype, device=param.device)
+ # else:
+ # return value.to(device=param.device)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Perform a single optimization step.
+
+ Args:
+ closure (Callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ self._cuda_graph_capture_health_check()
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ scale_exp_avgs = []
+ expand_exp_avgs = []
+ sqrt_minmax_exp_avgs = []
+ exp_avg_sqs = []
+ scale_exp_avg_sqs = []
+ expand_exp_avg_sqs = []
+ sqrt_minmax_exp_avg_sqs = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group["amsgrad"]
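+            # "expansion" / "true" selects the dynamic-range-expansion (x^k) variant of the FP8 kernel.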
+ use_expansion = self.qargs.first_order_expansion in ["expansion", "true"]
+ beta1, beta2 = group["betas"]
+
+ self._init_group(
+ group,
+ params_with_grad,
+ grads,
+ amsgrad,
+ use_expansion,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ )
+
+ Coatadamw(
+ self.qargs,
+ params_with_grad,
+ grads,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ use_expansion=use_expansion,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group["lr"],
+ weight_decay=group["weight_decay"],
+ eps=group["eps"],
+ qgroup_size=self.qargs.qgroup_size,
+ expand_min=self.qargs.expand_min,
+ fused=group["fused"],
+ grad_scale=getattr(self, "grad_scale", None),
+ found_inf=getattr(self, "found_inf", None),
+ )
+
+ return loss
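+
+# Minimal usage sketch (illustrative only; assumes `qargs` exposes the fields read above,
+# e.g. built with types.SimpleNamespace):
+#     qargs = SimpleNamespace(first_order_bit="E4M3", second_order_bit="E4M3",
+#                             first_order_expansion="expansion", second_order_expansion="expansion",
+#                             qgroup_size=128, expand_min=16)
+#     optimizer = CoatAdamW(qargs, model.parameters(), lr=1e-4, weight_decay=0.1)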
+
+
+def Coatadamw(
+ qargs,
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ scale_exp_avgs: List[Tensor],
+ expand_exp_avgs: List[Tensor],
+ sqrt_minmax_exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ scale_exp_avg_sqs: List[Tensor],
+ expand_exp_avg_sqs: List[Tensor],
+ sqrt_minmax_exp_avg_sqs: List[Tensor],
+ max_exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
+ # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
+ fused: Optional[bool] = None,
+ grad_scale: Optional[Tensor] = None,
+ found_inf: Optional[Tensor] = None,
+ *,
+ amsgrad: bool,
+ use_expansion: bool,
+ beta1: float,
+ beta2: float,
+ lr: Union[float, Tensor],
+ weight_decay: float,
+ eps: float,
+ qgroup_size: int,
+ expand_min: int,
+):
+ r"""Functional API that performs AdamW algorithm computation.
+
+ See :class:`~torch.optim.AdamW` for details.
+ """
+ if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
+ raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
+
+ func = _single_tensor_Coatadamw
+
+ func(
+ qargs,
+ params,
+ grads,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ use_expansion=use_expansion,
+ beta1=beta1,
+ beta2=beta2,
+ lr=lr,
+ weight_decay=weight_decay,
+ eps=eps,
+ qgroup_size=qgroup_size,
+ expand_min=expand_min,
+ grad_scale=grad_scale,
+ found_inf=found_inf,
+ )
+
+
+def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference
+ if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
+ return x.sqrt()
+ else:
+        return math.sqrt(x)
+
+
+def _single_tensor_Coatadamw(
+ qargs,
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ scale_exp_avgs: List[Tensor],
+ expand_exp_avgs: List[Tensor],
+ sqrt_minmax_exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ scale_exp_avg_sqs: List[Tensor],
+ expand_exp_avg_sqs: List[Tensor],
+ sqrt_minmax_exp_avg_sqs: List[Tensor],
+ max_exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ grad_scale: Optional[Tensor],
+ found_inf: Optional[Tensor],
+ *,
+ amsgrad: bool,
+ use_expansion: bool,
+ beta1: float,
+ beta2: float,
+ lr: Union[Tensor, float],
+ weight_decay: float,
+ eps: float,
+ qgroup_size: int,
+ expand_min: int,
+):
+
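+    # AMP-style grad scaling / inf checks are not handled by the FP8 kernels in this path.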
+ assert grad_scale is None and found_inf is None
+
+ if torch.jit.is_scripting():
+ # this assert is due to JIT being dumb and not realizing that the ops below
+ # have overloads to handle both float and Tensor lrs, so we just assert it's
+ # a float since most people using JIT are using floats
+ assert isinstance(lr, float)
+
+ for i, param in enumerate(params):
+ grad = grads[i]
+ # First order
+ exp_avg = exp_avgs[i]
+ scale_exp_avg = scale_exp_avgs[i]
+ # Second order
+ exp_avg_sq = exp_avg_sqs[i]
+ scale_exp_avg_sq = scale_exp_avg_sqs[i]
+ step_t = state_steps[i]
+
+ # print(len(exp_avg.unique()), len(exp_avg_sq.unique()))
+ # print(f"{param.shape}, {grad.shape}, {exp_avg.shape}, {exp_avg_sq.shape}", file=open('debug.txt', 'a'))
+
+ # update step
+ step_t += 1
+ step = int(step_t.item())
+
+ # Perform Optimizer Step
+ if use_expansion:
+ expand_exp_avg = expand_exp_avgs[i]
+ sqrt_minmax_exp_avg = sqrt_minmax_exp_avgs[i]
+ expand_exp_avg_sq = expand_exp_avg_sqs[i]
+ sqrt_minmax_exp_avg_sq = sqrt_minmax_exp_avg_sqs[i]
+
+ qoptim_cuda.fp8_adamw_expand_step(
+ param,
+ grad,
+ exp_avg,
+ scale_exp_avg,
+ expand_exp_avg,
+ sqrt_minmax_exp_avg,
+ exp_avg_sq,
+ scale_exp_avg_sq,
+ expand_exp_avg_sq,
+ sqrt_minmax_exp_avg_sq,
+ beta1,
+ beta2,
+ lr,
+ weight_decay,
+ eps,
+ step,
+ qgroup_size,
+ expand_min,
+ )
+
+ else:
+ qoptim_cuda.fp8_adamw_step(
+ param,
+ grad,
+ exp_avg,
+ scale_exp_avg,
+ exp_avg_sq,
+ scale_exp_avg_sq,
+ beta1,
+ beta2,
+ lr,
+ weight_decay,
+ eps,
+ step,
+ qgroup_size,
+ )
diff --git a/llava/model/coat/optimizer/kernels/bindings.cpp b/llava/model/coat/optimizer/kernels/bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..aa3dbf8b145c263000ce3b7151bb73714c55f655
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/bindings.cpp
@@ -0,0 +1,10 @@
+#include <torch/extension.h>
+
+#include "include/fp8_adamw.h"
+#include "include/fp8_adamw_expand.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fp8_adamw_step", &FP8_AdamW, "Update the quantized optimizer states");
+ m.def("fp8_adamw_expand_step", &FP8_AdamW_expand,
+ "Update the quantized optimizer states, use polynomial expander");
+}
diff --git a/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..0d6753b8a2e35f20ec2070c6bce2c9e4acefa0d7
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e40fca6d032a0a3094e50a4fdb04e78998819de2652f0a82e49f49225dd46ba9
+size 235960
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o
new file mode 100644
index 0000000000000000000000000000000000000000..be1ed8ce2592d7ce94e249fd52af112a39de1af2
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:daebebfbb9d1a3dbb5e22acb50af317aa74b437f7aa965bd959b19010c5fd144
+size 244976
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o
new file mode 100644
index 0000000000000000000000000000000000000000..837c2b25b4a1759b529b7e5bf2f74a781a60b73c
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac9c6832be5dccfce3f25ae735c67459afffd12f8f747b5018b3121da3dfd2d7
+size 215440
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o
new file mode 100644
index 0000000000000000000000000000000000000000..b6c5c480b055b909abffc646f69cd563c615cc47
Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o differ
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o
new file mode 100644
index 0000000000000000000000000000000000000000..ec6051ec3846f04b0f2c0f3bbd6abce6a626c3b5
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:799978e9ba5a24e75b58d0e218cf522874bb983d15fee15adb02d9ac0f2d7a10
+size 216240
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o
new file mode 100644
index 0000000000000000000000000000000000000000..b29913e078a46679c6aedca63accff5ace9ca5b6
Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o differ
diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile
new file mode 100644
index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile
@@ -0,0 +1,6 @@
+all:
+ nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90
+run:
+ ./nvcc_qoptim
+clean:
+ rm -f nvcc_qoptim
diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu
new file mode 100644
index 0000000000000000000000000000000000000000..79f370e76aaf3b99fc53c7fe053731dc9358729f
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu
@@ -0,0 +1,478 @@
+#include <cooperative_groups.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <type_traits>
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+#define QGROUPSIZE 128
+#define QUANT_MIN_VAL 1e-20
+
+template <typename T>
+inline float fp8_dtype_max(const T &variable) {
+  if (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 448;
+  } else if (std::is_same<T, __nv_fp8_e5m2>::value) {
+ return 57344;
+ } else {
+ throw "Unsupported data format";
+ }
+}
+
+typedef enum { fp8_adamw } myCsrcKernels;
+
+void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg,
+ float *fp_exp_avg_sq, float beta1, float beta2, float lr,
+ float wd, float eps, int step, int qgroup_size, int M,
+ int N) {
+ for (int idx = 0; idx < M * N; idx++) {
+ fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx];
+ fp_exp_avg_sq[idx] =
+ beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx];
+
+ const float correction1 = 1.0f - powf(beta1, step);
+ const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ float denom =
+ (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1;
+ float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ }
+}
+
+template <typename T>
+void printFloatArrayToFile(T *array, int M, int N,
+ const std::string &outputFileName) {
+ std::ofstream outputFile(outputFileName);
+ if (!outputFile.is_open()) {
+ std::cout << "Failed to open the file." << std::endl;
+ return;
+ }
+
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ int index = i * N + j;
+ outputFile << std::setw(10) << std::right << std::fixed
+ << std::setprecision(6) << (float)array[index] << " ";
+ if (j == N - 1) {
+ outputFile << "\n";
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void fp8_adamw_csrc(
+ scalar_t *__restrict__ params, scalar_t *__restrict__ grads,
+ __nv_fp8_e4m3 *__restrict__ exp_avg, float *__restrict__ scale_exp_avg,
+ float *__restrict__ expand_exp_avg, float *__restrict__ sqrtminmax_exp_avg,
+ __nv_fp8_e4m3 *__restrict__ exp_avg_sq,
+ float *__restrict__ scale_exp_avg_sq, float *__restrict__ expand_exp_avg_sq,
+ float *__restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size, int expand_min,
+ int total_elements, int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg = sign_exp_avg *
+ powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) *
+ sqrtminmax_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+ float_exp_avg_sq =
+ powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) *
+ sqrtminmax_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedFirstMinVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ __shared__ float sharedSecondMinVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float firstMinVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+ float secondMinVal = fabsf(float_exp_avg_sq);
+ // Special Handel
+ if (idx >= total_elements) {
+ firstMinVal = __int_as_float(0x7f7fffff);
+ secondMinVal = __int_as_float(0x7f7fffff);
+ }
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) {
+ sharedFirstMaxVal[wid] = firstMaxVal;
+ sharedFirstMinVal[wid] = firstMinVal;
+ sharedSecondMaxVal[wid] = secondMaxVal;
+ sharedSecondMinVal[wid] = secondMinVal;
+ }
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmin_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ __shared__ float shared_absmin_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ firstMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMinVal[lane] : 1e9;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ secondMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceFirstMinVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ float reduceSecondMinVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ }
+ if (lane == 0) {
+ shared_absmax_exp_avg = firstMaxVal;
+ shared_absmin_exp_avg = firstMinVal;
+ shared_absmax_exp_avg_sq = secondMaxVal;
+ shared_absmin_exp_avg_sq = secondMinVal;
+ }
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ // scaling factor before expanding
+ float fp8MaxVal = 448;
+
+ // dynamic exponent quantization part
+ firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL;
+ secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+ secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL;
+
+ // calculate the ratio and make the scale to center
+ float firstRatio = firstMaxVal / firstMinVal;
+ float secondRatio = secondMaxVal / secondMinVal;
+ float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal);
+ float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal);
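+    // Dividing by sqrt(max * min) centers each group's dynamic range around 1 before the x^k expansion.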
+
+ // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal,
+ // float_exp_avg);
+
+ // since we use x^k expander, calculate the optimal expanding factor
+ float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2;
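+    // The exponent is chosen so the expanded ratio (max/min)^k stays below ratioUpperBound.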
+ float firstExp =
+ floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) /
+ expand_min; // expand_min is set to 8 for example, then the firstExp is
+ // the multiple of 1/8
+ float secondExp =
+ floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) /
+ expand_min;
+
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg =
+ sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp);
+ float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp);
+
+ // correspondingly, change the scaling factor
+ float new_scale_exp_avg =
+ powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal;
+ float new_scale_exp_avg_sq =
+ powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal;
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ expand_exp_avg[scale_idx] = firstExp;
+ expand_exp_avg_sq[scale_idx] = secondExp;
+ sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax;
+ sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax;
+ }
+}
+
+template <myCsrcKernels algo>
+void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg,
+ float *scale_exp_avg, float *expand_exp_avg,
+ float *sqrtminmax_exp_avg, __nv_fp8_e4m3 *exp_avg_sq,
+ float *scale_exp_avg_sq, float *expand_exp_avg_sq,
+ float *sqrtminmax_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size,
+ int expand_min, int M, int N) {
+ if (algo == fp8_adamw) {
+ const int block_dim = 128;
+ int grid_dim = (M * N + qgroup_size - 1) / block_dim;
+ const dim3 gridDim(grid_dim);
+ const dim3 blockDim(block_dim);
+ printf("Yes!\n");
+    fp8_adamw_csrc<float><<<gridDim, blockDim>>>(
+ params, grads, exp_avg, scale_exp_avg, expand_exp_avg,
+ sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq, expand_exp_avg_sq,
+ sqrtminmax_exp_avg_sq, beta1, beta2, lr, wd, eps, step, qgroup_size,
+ expand_min, M * N, int(floor(M * N / 128.)));
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in kernel launch: "
+ << cudaGetErrorString(error) << std::endl;
+ return;
+ }
+ printf("Finish!\n");
+ }
+}
+
+float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *,
+ float *, float *, float *,
+ __nv_fp8_e4m3 *, float *, float *,
+ float *, // tensor input
+ float, float, float, float, float, int,
+ int, int, // float and int input
+ int, int), // M and N
+ int M, int N) {
+ size_t size_param = M * N * sizeof(float);
+ size_t size_optim = M * N * sizeof(__nv_fp8_e4m3);
+ size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float);
+
+ // host tensor
+ float *h_p, *h_g;
+ __nv_fp8_e4m3 *h_m, *h_v;
+ float *h_sm, *h_sv;
+ float *h_fp_m, *h_fp_v;
+ float *h_cpd_m, *h_cpd_v;
+ float *h_sqrtmm_m, *h_sqrtmm_v;
+
+ // device tensor
+ float *d_p, *d_g;
+ __nv_fp8_e4m3 *d_m, *d_v;
+ float *d_sm, *d_sv;
+ float *d_cpd_m, *d_cpd_v;
+ float *d_sqrtmm_m, *d_sqrtmm_v;
+
+ // device tensor transfer to host
+ float *hd_p, *hd_g;
+ __nv_fp8_e4m3 *hd_m, *hd_v;
+ float *hd_sm, *hd_sv;
+ float *hd_fp_m, *hd_fp_v;
+ float *hd_cpd_m, *hd_cpd_v;
+ float *hd_sqrtmm_m, *hd_sqrtmm_v;
+
+ h_p = (float *)malloc(size_param);
+ h_g = (float *)malloc(size_param);
+ h_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_sm = (float *)malloc(size_scale);
+ h_sv = (float *)malloc(size_scale);
+ h_cpd_m = (float *)malloc(size_scale);
+ h_cpd_v = (float *)malloc(size_scale);
+ h_sqrtmm_m = (float *)malloc(size_scale);
+ h_sqrtmm_v = (float *)malloc(size_scale);
+ h_fp_m = (float *)malloc(size_param);
+ h_fp_v = (float *)malloc(size_param);
+ cudaMalloc(&d_p, size_param);
+ cudaMalloc(&d_g, size_param);
+ cudaMalloc(&d_m, size_optim);
+ cudaMalloc(&d_v, size_optim);
+ cudaMalloc(&d_sm, size_scale);
+ cudaMalloc(&d_sv, size_scale);
+ cudaMalloc(&d_cpd_m, size_scale);
+ cudaMalloc(&d_cpd_v, size_scale);
+ cudaMalloc(&d_sqrtmm_m, size_scale);
+ cudaMalloc(&d_sqrtmm_v, size_scale);
+ hd_p = (float *)malloc(size_param);
+ hd_g = (float *)malloc(size_param);
+ hd_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_sm = (float *)malloc(size_scale);
+ hd_sv = (float *)malloc(size_scale);
+ hd_fp_m = (float *)malloc(size_param);
+ hd_fp_v = (float *)malloc(size_param);
+ hd_cpd_m = (float *)malloc(size_scale);
+ hd_cpd_v = (float *)malloc(size_scale);
+ hd_sqrtmm_m = (float *)malloc(size_scale);
+ hd_sqrtmm_v = (float *)malloc(size_scale);
+
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data copy: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ srand(0);
+ // random initialization for CPU tensor
+ for (int i = 0; i < M * N; i++) {
+ h_p[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_g[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ }
+ for (int i = 0; i < int(ceilf(M * N / 128.)); i++) {
+ h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_cpd_m[i] = 2;
+ h_cpd_v[i] = 3. / 8.;
+ h_sqrtmm_m[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_sqrtmm_v[i] = (float)(rand() / (float(RAND_MAX) / 10));
+
+ printf("scale is %f\n", h_sm[i]);
+ }
+ for (int i = 0; i < M * N; i++) {
+ h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))];
+ h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))];
+ }
+ float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8;
+ int step = 100, qgroup_size = 128, expand_min = 16;
+
+ printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt");
+
+ cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_cpd_m, h_cpd_m, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_cpd_v, h_cpd_v, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sqrtmm_m, h_sqrtmm_m, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sqrtmm_v, h_sqrtmm_v, size_scale, cudaMemcpyHostToDevice);
+
+ fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step,
+ qgroup_size, M, N);
+
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data initialization: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ myGPUKernel(d_p, d_g, d_m, d_sm, d_cpd_m, d_sqrtmm_m, d_v, d_sv, d_cpd_v,
+ d_sqrtmm_v, beta1, beta2, lr, wd, eps, step, qgroup_size,
+ expand_min, M, N);
+
+ cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_cpd_m, d_cpd_m, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_cpd_v, d_cpd_v, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sqrtmm_m, d_sqrtmm_m, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sqrtmm_v, d_sqrtmm_v, size_scale, cudaMemcpyDeviceToHost);
+
+ for (int i = 0; i < M * N; i++) {
+ hd_fp_m[i] = pow((float)hd_m[i] * hd_sm[int(floor(i / 128.))],
+ 1 / hd_cpd_m[int(floor(i / 128.))]) *
+ hd_sqrtmm_m[int(floor(i / 128.))];
+ hd_fp_v[i] = pow((float)hd_v[i] * hd_sv[int(floor(i / 128.))],
+ 1 / hd_cpd_v[int(floor(i / 128.))]) *
+ hd_sqrtmm_v[int(floor(i / 128.))];
+ }
+ printFloatArrayToFile(h_p, M, N, "CPU_param.txt");
+ printFloatArrayToFile(hd_p, M, N, "GPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "CPU_grad.txt");
+ printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt");
+ printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt");
+ printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt");
+ printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt");
+ printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt");
+ printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt");
+ printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt");
+
+ printFloatArrayToFile(hd_cpd_m, 1, int(ceilf(M * N / 128.)), "GPU_cpd_m.txt");
+ printFloatArrayToFile(hd_cpd_v, 1, int(ceilf(M * N / 128.)), "GPU_cpd_v.txt");
+ printFloatArrayToFile(hd_sqrtmm_m, 1, int(ceilf(M * N / 128.)),
+ "GPU_sqrtmm_m.txt");
+ printFloatArrayToFile(hd_sqrtmm_v, 1, int(ceilf(M * N / 128.)),
+ "GPU_sqrtmm_v.txt");
+
+ return 0.;
+}
+
+int main() {
+ const int M = 1, N = 7;
+  float max_error = testMaxError(myKernelLauncher<fp8_adamw>, M, N);
+}
diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile
new file mode 100644
index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile
@@ -0,0 +1,6 @@
+all:
+ nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90
+run:
+ ./nvcc_qoptim
+clean:
+ rm -f nvcc_qoptim
diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dde952654b71f11dc6de051a7089df38e14c6b57
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu
@@ -0,0 +1,354 @@
+#include <cooperative_groups.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <type_traits>
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+#define QGROUPSIZE 128
+#define QUANT_MIN_VAL 1e-20
+
+template <typename T>
+inline float fp8_dtype_max(const T &variable) {
+  if (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 448;
+  } else if (std::is_same<T, __nv_fp8_e5m2>::value) {
+ return 57344;
+ } else {
+ throw "Unsupported data format";
+ }
+}
+
+typedef enum { fp8_adamw } myCsrcKernels;
+
+void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg,
+ float *fp_exp_avg_sq, float beta1, float beta2, float lr,
+ float wd, float eps, int step, int qgroup_size, int M,
+ int N) {
+ for (int idx = 0; idx < M * N; idx++) {
+ fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx];
+ fp_exp_avg_sq[idx] =
+ beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx];
+
+ const float correction1 = 1.0f - powf(beta1, step);
+ const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ float denom =
+ (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1;
+ float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ }
+}
+
+template <typename T>
+void printFloatArrayToFile(T *array, int M, int N,
+ const std::string &outputFileName) {
+ std::ofstream outputFile(outputFileName);
+ if (!outputFile.is_open()) {
+ std::cout << "Failed to open the file." << std::endl;
+ return;
+ }
+
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ int index = i * N + j;
+ outputFile << std::setw(10) << std::right << std::fixed
+ << std::setprecision(6) << (float)array[index] << " ";
+ if (j == N - 1) {
+ outputFile << "\n";
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void fp8_adamw_csrc(scalar_t *__restrict__ params,
+ scalar_t *__restrict__ grads,
+ __nv_fp8_e4m3 *__restrict__ exp_avg,
+ float *__restrict__ scale_exp_avg,
+ __nv_fp8_e4m3 *__restrict__ exp_avg_sq,
+ float *__restrict__ scale_exp_avg_sq,
+ float beta1, float beta2, float lr, float wd,
+ float eps, int step, int qgroup_size,
+ int total_elements, int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal;
+ if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal;
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ }
+ if (lane == 0) shared_absmax_exp_avg = firstMaxVal;
+ if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal;
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ float fp8MaxVal = 448;
+
+ shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+
+ float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal;
+ float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal;
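+    // Per-group scale: absmax / 448 maps the group's largest magnitude onto the top of the E4M3 range.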
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ }
+}
+
+template <myCsrcKernels algo>
+void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg,
+ float *scale_exp_avg, __nv_fp8_e4m3 *exp_avg_sq,
+ float *scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size,
+ int M, int N) {
+ if (algo == fp8_adamw) {
+ const int block_dim = 128;
+ int grid_dim = (M * N + qgroup_size - 1) / block_dim;
+ const dim3 gridDim(grid_dim);
+ const dim3 blockDim(block_dim);
+ printf("Yes!\n");
+    fp8_adamw_csrc<float><<<gridDim, blockDim>>>(
+ params, grads, exp_avg, scale_exp_avg, exp_avg_sq, scale_exp_avg_sq,
+ beta1, beta2, lr, wd, eps, step, qgroup_size, M * N,
+ int(floor(M * N / 128.)));
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in kernel launch: "
+ << cudaGetErrorString(error) << std::endl;
+ return;
+ }
+ printf("Finish!\n");
+ }
+}
+
+float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *,
+ float *, __nv_fp8_e4m3 *, float *, float,
+ float, float, float, float, int, int,
+ int, int),
+ int M, int N) {
+ size_t size_param = M * N * sizeof(float);
+ size_t size_optim = M * N * sizeof(__nv_fp8_e4m3);
+ size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float);
+
+ // host tensor
+ float *h_p, *h_g;
+ __nv_fp8_e4m3 *h_m, *h_v;
+ float *h_sm, *h_sv;
+ float *h_fp_m, *h_fp_v;
+
+ // device tensor
+ float *d_p, *d_g;
+ __nv_fp8_e4m3 *d_m, *d_v;
+ float *d_sm, *d_sv;
+
+ // device tensor transfer to host
+ float *hd_p, *hd_g;
+ __nv_fp8_e4m3 *hd_m, *hd_v;
+ float *hd_sm, *hd_sv;
+ float *hd_fp_m, *hd_fp_v;
+
+ h_p = (float *)malloc(size_param);
+ h_g = (float *)malloc(size_param);
+ h_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_sm = (float *)malloc(size_scale);
+ h_sv = (float *)malloc(size_scale);
+ h_fp_m = (float *)malloc(size_param);
+ h_fp_v = (float *)malloc(size_param);
+ cudaMalloc(&d_p, size_param);
+ cudaMalloc(&d_g, size_param);
+ cudaMalloc(&d_m, size_optim);
+ cudaMalloc(&d_v, size_optim);
+ cudaMalloc(&d_sm, size_scale);
+ cudaMalloc(&d_sv, size_scale);
+ hd_p = (float *)malloc(size_param);
+ hd_g = (float *)malloc(size_param);
+ hd_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_sm = (float *)malloc(size_scale);
+ hd_sv = (float *)malloc(size_scale);
+ hd_fp_m = (float *)malloc(size_param);
+ hd_fp_v = (float *)malloc(size_param);
+
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data copy: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ srand(0);
+ // random initialization for CPU tensor
+ for (int i = 0; i < M * N; i++) {
+ h_p[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_g[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ }
+ for (int i = 0; i < int(ceilf(M * N / 128.)); i++) {
+ h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ printf("scale is %f\n", h_sm[i]);
+ }
+ for (int i = 0; i < M * N; i++) {
+ h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))];
+ h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))];
+ }
+ float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8;
+ int step = 100, qgroup_size = 128;
+
+ printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt");
+
+ cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice);
+
+ fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step,
+ qgroup_size, M, N);
+
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data initialization: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ myGPUKernel(d_p, d_g, d_m, d_sm, d_v, d_sv, beta1, beta2, lr, wd, eps, step,
+ qgroup_size, M, N);
+
+ cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost);
+
+ for (int i = 0; i < M * N; i++) {
+ hd_fp_m[i] = (float)hd_m[i] * hd_sm[int(floor(i / 128.))];
+ hd_fp_v[i] = (float)hd_v[i] * hd_sv[int(floor(i / 128.))];
+ }
+ printFloatArrayToFile(h_p, M, N, "CPU_param.txt");
+ printFloatArrayToFile(hd_p, M, N, "GPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "CPU_grad.txt");
+ printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt");
+ printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt");
+ printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt");
+ printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt");
+ printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt");
+ printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt");
+ printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt");
+
+ return 0.;
+}
+
+int main() {
+ const int M = 1, N = 7;
+  float max_error = testMaxError(myKernelLauncher<fp8_adamw>, M, N);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..57190c713309a19eafe35b4651e55d7c5bd576f7
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp
@@ -0,0 +1,26 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+
+void FP8_AdamW_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step,
+ int qgroup_size // other parameters
+);
+
+void FP8_AdamW(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step,
+ int qgroup_size) { // other parameters
+
+ FP8_AdamW_cuda(params, grads, exp_avg, scale_exp_avg, exp_avg_sq,
+ scale_exp_avg_sq, beta1, beta2, lr, wd, eps, step,
+ qgroup_size);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e652845af5d228237d3c8df758df1a4fe94bcd4
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu
@@ -0,0 +1,163 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cooperative_groups.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+#include <math.h>
+#include <stdio.h>
+#include <torch/extension.h>
+
+#define QUANT_MIN_VAL 1e-20
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+
+template <typename scalar_t>
+__global__ void fp8_adamw_cuda_kernel(
+ scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
+ __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
+ __nv_fp8_e4m3* __restrict__ exp_avg_sq,
+ float* __restrict__ scale_exp_avg_sq, float beta1, float beta2, float lr,
+ float wd, float eps, int step, int qgroup_size, int total_elements,
+ int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal;
+ if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal;
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ }
+ if (lane == 0) shared_absmax_exp_avg = firstMaxVal;
+ if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal;
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ float fp8MaxVal = 448;
+
+ shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+
+ float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal;
+ float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal;
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ }
+}
+
+void FP8_AdamW_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step,
+ int qgroup_size) { // other parameters
+
+ // CUDA Blocks
+ int total_elements = params.numel();
+ int total_scale_elements = scale_exp_avg.numel();
+ AT_ASSERTM(qgroup_size == 128,
+ "Only per-group quantization with group size 128 is supported currently");
+ const int block_dim = 128; // This should equal the qgroup_size
+ int grid_dim = (total_elements + block_dim - 1) / block_dim;
+ AT_ASSERTM(grid_dim == scale_exp_avg.numel());
+ AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel());
+ const dim3 blocks(grid_dim);
+
+ // Execution
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
+ fp8_adamw_cuda_kernel<scalar_t><<<blocks, block_dim>>>(
+ params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
+ (__nv_fp8_e4m3*)exp_avg.data_ptr(),
+ scale_exp_avg.data_ptr<float>(),
+ (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(),
+ scale_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps, step,
+ qgroup_size, total_elements, total_scale_elements);
+ }));
+}
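
The launcher above dispatches a straightforward per-group absmax quantization: each 128-element group is scaled so that its largest magnitude maps to the FP8 E4M3 maximum of 448, and the momenta are dequantized with the same scale at the start of the next step. A minimal PyTorch sketch of that arithmetic is given below for reference; it is illustrative only (not repository code) and assumes a tensor length divisible by 128 and a PyTorch build that exposes `torch.float8_e4m3fn`.

```python
# Minimal sketch of the per-group E4M3 quantization used by fp8_adamw_cuda_kernel.
# Not repository code; assumes numel is a multiple of 128 and torch >= 2.1.
import torch

QGROUP_SIZE = 128
FP8_E4M3_MAX = 448.0
QUANT_MIN_VAL = 1e-20

def quantize_state(x: torch.Tensor):
    """Quantize a flat optimizer-state tensor group-wise to FP8 E4M3."""
    groups = x.view(-1, QGROUP_SIZE)                        # one group per CUDA block
    scale = (groups.abs().amax(dim=1) + QUANT_MIN_VAL) / FP8_E4M3_MAX
    q = (groups / scale[:, None]).to(torch.float8_e4m3fn)   # round to E4M3
    return q, scale

def dequantize_state(q: torch.Tensor, scale: torch.Tensor):
    # mirrors float(exp_avg[idx]) * scale_exp_avg[scale_idx] in the kernel
    return (q.to(torch.float32) * scale[:, None]).view(-1)

if __name__ == "__main__":
    m = torch.randn(4 * QGROUP_SIZE)
    q, s = quantize_state(m)
    print((dequantize_state(q, s) - m).abs().max())          # per-group quantization error
```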
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d50f96dc9d4dc186b660d6b6ff56ff9df1c2df89
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp
@@ -0,0 +1,34 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+
+void FP8_AdamW_expand_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size, int expand_min // other parameters
+);
+
+void FP8_AdamW_expand(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size, int expand_min) { // other parameters
+
+ FP8_AdamW_expand_cuda(params, grads, exp_avg, scale_exp_avg, expand_exp_avg,
+ sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq,
+ expand_exp_avg_sq, sqrtminmax_exp_avg_sq, beta1, beta2,
+ lr, wd, eps, step, qgroup_size, expand_min);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6130e3716fe5d73bf66cc6f915c3bf99a41f785b
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu
@@ -0,0 +1,247 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <cooperative_groups.h>
+#include <cmath>
+
+#define QUANT_MIN_VAL 1e-20
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+
+template <typename scalar_t>
+__global__ void fp8_adamw_cuda_expand_kernel(
+ scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
+ __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
+ float* __restrict__ expand_exp_avg, float* __restrict__ sqrtminmax_exp_avg,
+ __nv_fp8_e4m3* __restrict__ exp_avg_sq,
+ float* __restrict__ scale_exp_avg_sq, float* __restrict__ expand_exp_avg_sq,
+ float* __restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size, int expand_min,
+ int total_elements, int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg = sign_exp_avg *
+ powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) *
+ sqrtminmax_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+ float_exp_avg_sq =
+ powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) *
+ sqrtminmax_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedFirstMinVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ __shared__ float sharedSecondMinVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float firstMinVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+ float secondMinVal = fabsf(float_exp_avg_sq);
+ // special handling: out-of-range threads must not affect the min reduction
+ if (idx >= total_elements) {
+ firstMinVal = __int_as_float(0x7f7fffff);
+ secondMinVal = __int_as_float(0x7f7fffff);
+ }
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) {
+ sharedFirstMaxVal[wid] = firstMaxVal;
+ sharedFirstMinVal[wid] = firstMinVal;
+ sharedSecondMaxVal[wid] = secondMaxVal;
+ sharedSecondMinVal[wid] = secondMinVal;
+ }
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmin_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ __shared__ float shared_absmin_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ firstMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMinVal[lane] : 1e9;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ secondMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceFirstMinVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ float reduceSecondMinVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ }
+ if (lane == 0) {
+ shared_absmax_exp_avg = firstMaxVal;
+ shared_absmin_exp_avg = firstMinVal;
+ shared_absmax_exp_avg_sq = secondMaxVal;
+ shared_absmin_exp_avg_sq = secondMinVal;
+ }
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ // scaling factor before expanding
+ float fp8MaxVal = 448;
+
+ // dynamic exponent quantization part
+ firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL;
+ secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+ secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL;
+
+ // calculate the ratio and make the scale to center
+ float firstRatio = firstMaxVal / firstMinVal;
+ float secondRatio = secondMaxVal / secondMinVal;
+ float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal);
+ float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal);
+
+ // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal,
+ // float_exp_avg);
+
+ // since we use x^k expander, calculate the optimal expanding factor
+ float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2;
+ float firstExp =
+ floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) /
+ expand_min; // expand_min is set to 8 for example, then the firstExp is
+ // the multiple of 1/8
+ float secondExp =
+ floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) /
+ expand_min;
+
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg =
+ sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp);
+ float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp);
+
+ // correspondingly, change the scaling factor
+ float new_scale_exp_avg =
+ powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal;
+ float new_scale_exp_avg_sq =
+ powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal;
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ expand_exp_avg[scale_idx] = firstExp;
+ expand_exp_avg_sq[scale_idx] = secondExp;
+ sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax;
+ sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax;
+ }
+}
+
+void FP8_AdamW_expand_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size,
+ int expand_min) { // other parameters
+
+ // CUDA Blocks
+ int total_elements = params.numel();
+ int total_scale_elements = scale_exp_avg.numel();
+ AT_ASSERTM(qgroup_size == 128,
+ "Only per-group quantization with group size 128 is supported currently");
+ const int block_dim = 128; // This should equal the qgroup_size
+ int grid_dim = (total_elements + block_dim - 1) / block_dim;
+ AT_ASSERTM(grid_dim == scale_exp_avg.numel());
+ AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel());
+ const dim3 blocks(grid_dim);
+
+ // Execution
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
+ fp8_adamw_cuda_expand_kernel<scalar_t><<<blocks, block_dim>>>(
+ params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
+ (__nv_fp8_e4m3*)exp_avg.data_ptr(),
+ scale_exp_avg.data_ptr<float>(), expand_exp_avg.data_ptr<float>(),
+ sqrtminmax_exp_avg.data_ptr<float>(),
+ (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(),
+ scale_exp_avg_sq.data_ptr<float>(),
+ expand_exp_avg_sq.data_ptr<float>(),
+ sqrtminmax_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps,
+ step, qgroup_size, expand_min, total_elements,
+ total_scale_elements);
+ }));
+}
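
Compared with the plain kernel, the expand variant additionally stores a per-group exponent (`expand_exp_avg`) and the geometric mean of the group's min and max magnitudes (`sqrtminmax_exp_avg`). Values are centred by that geometric mean and raised to an exponent k, chosen as a multiple of 1/expand_min, so the group's dynamic range fits what E4M3 can represent; dequantization inverts the transform with the 1/k power. The NumPy sketch below mirrors that arithmetic for a single group; it is for illustration only and is not repository code.

```python
# Sketch of the "x^k expander" transform in fp8_adamw_cuda_expand_kernel for one group.
import numpy as np

FP8_E4M3_MAX = 448.0

def expand_group(v, expand_min=8):
    """Return (fp8-ready values, scale, k, sqrt_minmax) for one quantization group."""
    a = np.abs(v)
    vmax, vmin = a.max() + 1e-20, a.min() + 1e-20
    sqrt_minmax = np.sqrt(vmax * vmin)
    ratio_upper = FP8_E4M3_MAX * FP8_E4M3_MAX / 2
    # k is floored to a multiple of 1/expand_min, exactly as in the kernel
    k = np.floor(np.log2(ratio_upper) / np.log2(vmax / vmin) * expand_min) / expand_min
    expanded = np.sign(v) * (a / sqrt_minmax) ** k
    scale = (vmax / sqrt_minmax) ** k / FP8_E4M3_MAX
    return expanded / scale, scale, k, sqrt_minmax

def unexpand(q, scale, k, sqrt_minmax):
    # mirrors the dequantization at the top of the kernel
    x = q * scale
    return np.sign(x) * np.abs(x) ** (1.0 / k) * sqrt_minmax

v = np.random.randn(128) * 1e-3
q, s, k, m = expand_group(v)
print(np.max(np.abs(unexpand(q, s, k, m) - v)))   # reconstruction error before FP8 rounding
```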
diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h
new file mode 100644
index 0000000000000000000000000000000000000000..c04970da0dd564ea27e1e6203a38206d9506cac8
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h
@@ -0,0 +1,14 @@
+#ifndef FP8_ADAMW
+#define FP8_ADAMW
+
+#include <torch/extension.h>
+
+void FP8_AdamW(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size);
+
+#endif // FP8_ADAMW
diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1fb69825933154b1a867efaecd54cbdf100201b
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h
@@ -0,0 +1,18 @@
+#ifndef FP8_ADAMW_EXPAND
+#define FP8_ADAMW_EXPAND
+
+#include <torch/extension.h>
+
+void FP8_AdamW_expand(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size, int expand_min);
+
+#endif  // FP8_ADAMW_EXPAND
diff --git a/llava/model/coat/optimizer/kernels/setup.py b/llava/model/coat/optimizer/kernels/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..8453efeaa472f3b094a2c157acbebf4c571eb09d
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/setup.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name="qoptim_cuda",
+ ext_modules=[
+ CUDAExtension(
+ name="qoptim_cuda",
+ sources=[
+ "fp8_adamw_cuda.cpp",
+ "fp8_adamw_cuda_kernel.cu",
+ "fp8_adamw_expand_cuda.cpp",
+ "fp8_adamw_expand_cuda_kernel.cu",
+ "bindings.cpp",
+ ],
+ # include_dirs=[
+ # 'include'
+ # ],
+ extra_compile_args={
+ "nvcc": [
+ "-O3",
+ "-std=c++17",
+ "-gencode=arch=compute_90,code=compute_90",
+ "-DTORCH_USE_CUDA_DSA",
+ "-U__CUDA_NO_HALF_OPERATORS__",
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ "-U__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ },
+ ),
+ ],
+ cmdclass={"build_ext": BuildExtension},
+)
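
Note that the only `-gencode` flag targets compute capability 9.0, so the kernels are aimed at Hopper-class GPUs. Once the extension is built (for example with `python setup.py install` in this directory), the optimizer wrapper elsewhere in the repository can call into it. The snippet below is a hedged sketch of a direct call: the Python-visible function name comes from `bindings.cpp`, which is not shown here, so `qoptim_cuda.FP8_AdamW` and the FP8 storage dtype of the state tensors are assumptions.

```python
# Hypothetical driver for the extension built by setup.py. The exported name and the
# float8 storage dtype are assumptions based on the C++ entry point FP8_AdamW.
import torch
import qoptim_cuda  # assumed module name, from setup(name="qoptim_cuda", ...)

QGROUP = 128
p = torch.randn(4096, device="cuda", dtype=torch.bfloat16)   # 32 groups of 128
g = torch.randn_like(p)
n_groups = p.numel() // QGROUP

exp_avg = torch.zeros(p.numel(), device="cuda", dtype=torch.float8_e4m3fn)
exp_avg_sq = torch.zeros_like(exp_avg)
scale_exp_avg = torch.zeros(n_groups, device="cuda", dtype=torch.float32)
scale_exp_avg_sq = torch.zeros_like(scale_exp_avg)

# beta1, beta2, lr, wd, eps, step, qgroup_size follow the C++ signature order
qoptim_cuda.FP8_AdamW(p, g, exp_avg, scale_exp_avg, exp_avg_sq, scale_exp_avg_sq,
                      0.9, 0.95, 1e-4, 0.0, 1e-8, 1, QGROUP)
```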
diff --git a/llava/model/configuration_llava.py b/llava/model/configuration_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..b413354ab70581faf619a632e1934a1a956250d0
--- /dev/null
+++ b/llava/model/configuration_llava.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Literal, Optional
+
+from pydantic import BaseModel, Field
+from transformers import PretrainedConfig
+
+
+class LlavaConfig(PretrainedConfig):
+ model_type = "llava"
+
+ def __init__(
+ self,
+ llm_cfg=None,
+ vision_tower_cfg=None,
+ speech_tower_cfg=None,
+ sound_tower_cfg=None,
+ mm_projector_cfg=None,
+ speech_mm_projector_cfg=None,
+ sound_mm_projector_cfg=None,
+ architectures=None,
+ resume_path=None,
+ hidden_size=None,
+ mm_hidden_size=None,
+ speech_hidden_size=None,
+ sound_hidden_size=None,
+ image_aspect_ratio=None,
+ num_video_frames=None,
+ fps=None,
+ mm_vision_select_layer=None,
+ mm_vision_select_feature=None,
+ mm_use_im_start_end=False,
+ mm_use_im_patch_token=False,
+ mm_projector_lr=None,
+ speech_mm_projector_lr=None,
+ sound_mm_projector_lr=None,
+ vision_tower_lr=None,
+ speech_tower_lr=None,
+ sound_tower_lr=None,
+ vision_resolution=None,
+ interpolate_mode=None,
+ s2=None,
+ dynamic_s2=None,
+ s2_scales=None,
+ s2_max_split_size=None,
+ s2_resize_output_to_scale_idx=0,
+ min_tiles: Optional[int] = 1,
+ max_tiles: Optional[int] = 12,
+ video_max_tiles: Optional[int] = 1,
+ num_time_tokens=None,
+ time_token_format=None,
+ image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
+ video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
+ speech_encoder: str = '{"_target_": "llava.model.encoders.BasicSpeechEncoder"}',
+ sound_encoder: str = '{"_target_": "llava.model.encoders.BasicSoundEncoder"}',
+ **kwargs,
+ ):
+ super().__init__()
+ self.architectures = architectures
+ self.llm_cfg = llm_cfg
+ self.vision_tower_cfg = vision_tower_cfg
+ self.speech_tower_cfg = speech_tower_cfg
+ self.sound_tower_cfg = sound_tower_cfg
+ self.mm_projector_cfg = mm_projector_cfg
+ self.speech_mm_projector_cfg = speech_mm_projector_cfg
+ self.sound_mm_projector_cfg = sound_mm_projector_cfg
+ self.resume_path = resume_path
+
+ self.hidden_size = hidden_size
+ self.mm_hidden_size = mm_hidden_size
+ self.speech_hidden_size = speech_hidden_size
+ self.sound_hidden_size = sound_hidden_size
+ self.image_aspect_ratio = image_aspect_ratio
+ self.num_video_frames = num_video_frames
+ self.fps = fps
+ self.mm_vision_select_layer = mm_vision_select_layer
+ self.mm_vision_select_feature = mm_vision_select_feature
+ self.mm_use_im_start_end = mm_use_im_start_end
+ self.mm_use_im_patch_token = mm_use_im_patch_token
+ self.mm_projector_lr = mm_projector_lr
+ self.speech_mm_projector_lr = speech_mm_projector_lr
+ self.sound_mm_projector_lr = sound_mm_projector_lr
+ self.vision_tower_lr = vision_tower_lr
+ self.speech_tower_lr = speech_tower_lr
+ self.sound_tower_lr = sound_tower_lr
+ self.vision_resolution = vision_resolution
+ self.interpolate_mode = interpolate_mode
+ self.s2 = s2
+ self.dynamic_s2 = dynamic_s2
+ self.s2_scales = s2_scales
+ self.s2_max_split_size = s2_max_split_size
+ self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
+ self.min_tiles = min_tiles
+ self.max_tiles = max_tiles
+ self.video_max_tiles = video_max_tiles
+ self.num_time_tokens = num_time_tokens
+ self.time_token_format = time_token_format
+
+ self.image_encoder = image_encoder
+ self.video_encoder = video_encoder
+ self.speech_encoder = speech_encoder
+ self.sound_encoder = sound_encoder
+
+
+class JsonSchemaResponseFormat(BaseModel):
+ schema_: str = Field(alias="schema")
+
+
+class ResponseFormat(BaseModel):
+ type: Literal["text", "json_object", "json_schema"]
+ json_schema: Optional[JsonSchemaResponseFormat] = None
diff --git a/llava/model/deprecate_consolidate.py b/llava/model/deprecate_consolidate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebb9619df19e4e8791920b4c22d38763538454b1
--- /dev/null
+++ b/llava/model/deprecate_consolidate.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Usage:
+python3 -m llava.model.deprecate_consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+"""
+import argparse
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llava.model import *
+from llava.model.utils import auto_upgrade
+
+
+def consolidate_ckpt(src_path, dst_path):
+ print("Loading model")
+ auto_upgrade(src_path)
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+ src_model.save_pretrained(dst_path)
+ src_tokenizer.save_pretrained(dst_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src", type=str, required=True)
+ parser.add_argument("--dst", type=str, required=True)
+
+ args = parser.parse_args()
+
+ consolidate_ckpt(args.src, args.dst)
diff --git a/llava/model/encoders/__init__.py b/llava/model/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e80d23d6693e11f0779e342c9b6384d78b4c79
--- /dev/null
+++ b/llava/model/encoders/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .image import *
+from .video import *
+from .speech import *
+from .sound import *
+
diff --git a/llava/model/encoders/__pycache__/__init__.cpython-310.pyc b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51821273a8633724bacf886f82c5d3f6dccef0f2
Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc differ
diff --git a/llava/model/encoders/__pycache__/__init__.cpython-311.pyc b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91555cc8635606593bafd5cbfc61fff128718d6e
Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc differ
diff --git a/llava/model/encoders/__pycache__/base.cpython-310.pyc b/llava/model/encoders/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1a5e6a8761723bbebdd2cec99f1d26496e05ce0
Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-310.pyc differ
diff --git a/llava/model/encoders/__pycache__/base.cpython-311.pyc b/llava/model/encoders/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e0467e667be5f975f22ccd0d080fe75c8ea560c
Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-311.pyc differ
diff --git a/llava/model/encoders/base.py b/llava/model/encoders/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c9230c954a0f711b063e9b9929cbe0e3d40b55a
--- /dev/null
+++ b/llava/model/encoders/base.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from torch import nn
+
+__all__ = ["BaseEncoder"]
+
+
+class BaseEncoder(nn.Module):
+ def __init__(self, parent: nn.Module) -> None:
+ super().__init__()
+ self._parent = [parent]
+
+ @property
+ def parent(self) -> nn.Module:
+ return self._parent[0]
diff --git a/llava/model/encoders/image/__init__.py b/llava/model/encoders/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb
--- /dev/null
+++ b/llava/model/encoders/image/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
diff --git a/llava/model/encoders/image/basic.py b/llava/model/encoders/image/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..6428e38f56ffe421edd25bf45f378b25849874f8
--- /dev/null
+++ b/llava/model/encoders/image/basic.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicImageEncoder"]
+
+
+class BasicImageEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ if start_token_embeds is not None:
+ features = torch.cat([start_token_embeds, features], dim=0)
+ if end_token_embeds is not None:
+ features = torch.cat([features, end_token_embeds], dim=0)
+ return features
+
+ def forward(self, images: List[torch.Tensor], config: Dict[str, Any], frame_times=None) -> List[torch.Tensor]:
+ images = torch.stack(images, dim=0)
+ features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
+
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
diff --git a/llava/model/encoders/sound/__init__.py b/llava/model/encoders/sound/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb
--- /dev/null
+++ b/llava/model/encoders/sound/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
diff --git a/llava/model/encoders/sound/basic.py b/llava/model/encoders/sound/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb5ba5af1ba40fc9a9cba5b41e37673a6cdfeb2
--- /dev/null
+++ b/llava/model/encoders/sound/basic.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicSoundEncoder"]
+
+
+class BasicSoundEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ features = features.to(self.parent.device)
+ if start_token_embeds is not None:
+ features = torch.cat([start_token_embeds, features], dim=0)
+ if end_token_embeds is not None:
+ features = torch.cat([features, end_token_embeds], dim=0)
+ return features
+
+ def forward(self, sounds: List[torch.Tensor], config: Dict[str, Any], masks: Dict[str, Any]) -> List[torch.Tensor]:
+ sounds = torch.stack(sounds, dim=0)
+ masks = torch.stack(masks, dim=0)
+ features = self.parent.encode_sound(sounds, masks)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
diff --git a/llava/model/encoders/speech/__init__.py b/llava/model/encoders/speech/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb
--- /dev/null
+++ b/llava/model/encoders/speech/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
diff --git a/llava/model/encoders/speech/basic.py b/llava/model/encoders/speech/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..68226501c5fcfc8d0cc6030ca71b324a7cb7f360
--- /dev/null
+++ b/llava/model/encoders/speech/basic.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicSpeechEncoder"]
+
+
+class BasicSpeechEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ if start_token_embeds is not None:
+ features = torch.cat([start_token_embeds, features], dim=0)
+ if end_token_embeds is not None:
+ features = torch.cat([features, end_token_embeds], dim=0)
+ return features
+
+ def forward(self, speeches: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+ speeches = torch.stack(speeches, dim=0)
+ features = self.parent.encode_speech(speeches)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
diff --git a/llava/model/encoders/video/__init__.py b/llava/model/encoders/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b00796ffe81fd8d1cb50cd604a353e8c4e8199ad
--- /dev/null
+++ b/llava/model/encoders/video/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
+from .tsp import *
diff --git a/llava/model/encoders/video/basic.py b/llava/model/encoders/video/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b84c6917874686d071ef60a5bd6f2cab0a7f9e4
--- /dev/null
+++ b/llava/model/encoders/video/basic.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicVideoEncoder"]
+
+
+class BasicVideoEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ if start_token_embeds is not None:
+ start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
+ features = torch.cat([start_embeds, features], dim=1)
+ if end_token_embeds is not None:
+ end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
+ features = torch.cat([features, end_embeds], dim=1)
+ return features.flatten(0, 1)
+
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+ num_frames = [video.shape[0] for video in videos]
+ images = torch.cat(videos, dim=0)
+ features = self.parent.encode_images(images)
+ features = torch.split(features, num_frames)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
diff --git a/llava/model/encoders/video/tsp.py b/llava/model/encoders/video/tsp.py
new file mode 100644
index 0000000000000000000000000000000000000000..efa06530b3aea80a994efb5bd3d5db70f46bd577
--- /dev/null
+++ b/llava/model/encoders/video/tsp.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from .basic import BasicVideoEncoder
+
+__all__ = ["TSPVideoEncoder"]
+
+
+def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
+ return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
+
+
+class TSPVideoEncoder(BasicVideoEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ pool_sizes: List[Tuple[int, int, int]],
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ sep_tokens: Optional[str] = None,
+ ) -> None:
+ super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
+ self.pool_sizes = pool_sizes
+ self.sep_tokens = sep_tokens
+
+ def _process_features(
+ self,
+ inputs: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ sep_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ nt, ns = inputs.shape[:2]
+ nl = int(ns**0.5)
+ outputs = []
+ for pool_size in self.pool_sizes:
+ features = inputs.view(nt, nl, nl, -1)
+ for dim, p in enumerate(pool_size):
+ features = pool(features, p, dim=dim)
+ features = features.flatten(1, 2)
+ features = super()._process_features(
+ features,
+ start_token_embeds=start_token_embeds,
+ end_token_embeds=end_token_embeds,
+ )
+ if sep_token_embeds is not None:
+ features = torch.cat([features, sep_token_embeds], dim=0)
+ outputs.append(features)
+ return torch.cat(outputs, dim=0)
+
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+ num_frames = [video.shape[0] for video in videos]
+ images = torch.cat(videos, dim=0)
+ features = self.parent.encode_images(images)
+ features = torch.split(features, num_frames)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ sep_token_embeds=self.embed_tokens(self.sep_tokens),
+ )
+ return [process_features(f) for f in features]
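
The `pool` helper reshapes a `(frames, height, width, channels)` feature map so that `size` consecutive entries along `dim` are folded into a new axis and averaged. The short check below (illustrative only, not repository code) shows how a `pool_sizes` entry of `(4, 2, 2)` reduces 8 frames of 24x24 patches to 2 temporal segments of 12x12.

```python
# Quick shape check of the pooling used by TSPVideoEncoder.
import torch

def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    # fold `size` consecutive entries along `dim` and average them
    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)

feats = torch.randn(8, 24, 24, 1024)      # (frames, height, width, channels)
for dim, p in enumerate((4, 2, 2)):       # e.g. pool_sizes = [(4, 2, 2)]
    feats = pool(feats, p, dim=dim)
print(feats.shape)                        # torch.Size([2, 12, 12, 1024])
```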
diff --git a/llava/model/language_model/builder.py b/llava/model/language_model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6936800c3c06666fc9d3d02f7424d0eb6409ea39
--- /dev/null
+++ b/llava/model/language_model/builder.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+import os
+import os.path as osp
+import warnings
+from dataclasses import asdict
+from typing import Tuple
+
+import torch
+from huggingface_hub import file_exists, repo_exists
+from huggingface_hub.utils import HFValidationError
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoModelForVision2Seq,
+ AutoTokenizer,
+ PretrainedConfig,
+ PreTrainedModel,
+ PreTrainedTokenizer,
+)
+
+
+from llava.constants import MEDIA_TOKENS
+from llava.model.utils import packing
+from llava.utils.logging import logger
+from llava.utils.tokenizer import infer_stop_tokens
+
+
+def has_tokenizer(repo_id_or_path: str) -> bool:
+ # Check if the tokenizer is in a local directory
+ if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
+ return True
+
+ # Check if the tokenizer is in a Hugging Face Hub repo
+ try:
+ return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
+ except HFValidationError:
+ return False
+
+
+def context_length_extension(config):
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
+ model_max_length = getattr(config, "model_max_length", None)
+ if orig_ctx_len and model_max_length > orig_ctx_len:
+ print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
+ scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
+ return config
+
+
+def build_llm_and_tokenizer(
+ model_name_or_path: str,
+ config: PretrainedConfig,
+ attn_implementation=None,
+ model_max_length=None,
+ *args,
+ **kwargs,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+ # print(model_name_or_path)
+ llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
+ llm_cfg._attn_implementation = attn_implementation
+ llm_cfg.model_max_length = model_max_length
+ if model_max_length is not None:
+ context_length_extension(llm_cfg)
+
+ # Quantization related
+ quantization_restore_from_checkpoint = False
+ if kwargs.get("quantize_model_class") is not None:
+ assert kwargs.get("model_args") is not None
+ quantize_model_class = kwargs.pop("quantize_model_class", None)
+ model_args = kwargs.pop("model_args", None)
+
+ if quantize_model_class == "QLlamaForCausalLM": # TODO: Also change the name of this class
+ from .qllama import QLlamaConfig
+
+ llm_cfg.architectures = "QLlamaForCausalLM"
+ _attn_implementation = llm_cfg._attn_implementation
+ llm_cfg = QLlamaConfig(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+ elif quantize_model_class == "QMemLlamaForCausalLM": # TODO: Also change the name of this class
+ from .qmemllama import QMemLlamaConfig
+
+ llm_cfg.architectures = "QMemLlamaForCausalLM"
+ llm_cfg = QMemLlamaConfig(**llm_cfg.to_dict())
+ elif quantize_model_class == "FP8LinearQwen2ForCausalLM":
+ from .configuration_quantize import QuantizationConfig
+ from .fp8linearqwen2 import FP8LinearQwen2Config
+
+ llm_cfg.architectures = "FP8LinearQwen2ForCausalLM"
+ coat_fp8_args = QuantizationConfig(**asdict(model_args))
+
+ # Remove the quantization args from llm_cfg and make it an independent config
+ model_args_dict = asdict(model_args)
+ for key in asdict(coat_fp8_args).keys():
+ model_args_dict.pop(key, None)
+
+ llm_cfg.coat_fp8_args = asdict(coat_fp8_args)
+ _attn_implementation = llm_cfg._attn_implementation
+
+ llm_cfg = FP8LinearQwen2Config(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+
+ elif quantize_model_class == "FP8ActivationQwen2ForCausalLM":
+ from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+ from .fp8activationqwen2 import FP8ActivationQwen2Config
+
+ quantization_restore_from_checkpoint = True
+
+ llm_cfg.architectures = "FP8ActivationQwen2ForCausalLM"
+ coat_fp8_args = QuantizationConfig(**asdict(model_args))
+
+ # Remove the quantization args from llm_cfg and make it an independent config
+ model_args_dict = asdict(model_args)
+ for key in asdict(coat_fp8_args).keys():
+ model_args_dict.pop(key, None)
+
+ llm_cfg.coat_fp8_args = asdict(coat_fp8_args)
+ _attn_implementation = llm_cfg._attn_implementation
+
+ llm_cfg = FP8ActivationQwen2Config(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+
+ elif quantize_model_class == "FP8ActivationResidualQwen2ForCausalLM":
+ from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+ from .fp8activationresidualqwen2 import FP8ActivationResidualQwen2Config
+
+ quantization_restore_from_checkpoint = True
+
+ llm_cfg.architectures = "FP8ActivationResidualQwen2ForCausalLM"
+ coat_fp8_args = QuantizationConfig(**asdict(model_args))
+
+ # Remove the quantization args from llm_cfg and make it an independent config
+ model_args_dict = asdict(model_args)
+ for key in asdict(coat_fp8_args).keys():
+ model_args_dict.pop(key, None)
+
+ llm_cfg.coat_fp8_args = asdict(coat_fp8_args)
+ _attn_implementation = llm_cfg._attn_implementation
+
+ llm_cfg = FP8ActivationResidualQwen2Config(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+ else:
+ raise ValueError(f"{quantize_model_class} is not supported quantize_model_class.")
+
+ kwargs.pop("quantize_model_class", None)
+
+ if quantize_model_class in [
+ "FP8LinearQwen2ForCausalLM",
+ "FP8ActivationQwen2ForCausalLM",
+ "FP8ActivationResidualQwen2ForCausalLM",
+ ]:  # the quantization args were already removed from model_args_dict above
+ llm_cfg.update(model_args_dict)
+ else:
+ llm_cfg.update(asdict(model_args))
+ # print(model_args)
+
+ if quantization_restore_from_checkpoint:
+ fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
+
+ llm = AutoModelForCausalLM.from_pretrained(
+ model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
+ )
+
+ else:
+ llm = AutoModelForCausalLM.from_pretrained(
+ model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
+ )
+ packing.patch(llm)
+
+ # Locate the tokenizer.
+ llm_path = model_name_or_path
+ if not has_tokenizer(llm_path):
+ llm_path = osp.join(llm_path, "llm")
+ if not has_tokenizer(llm_path):
+ raise ValueError(f"Cannot find tokenizer in {llm_path}.")
+
+ tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
+ if model_max_length is not None:
+ tokenizer.model_max_length = model_max_length
+
+ # Load chat template if specified.
+ if getattr(config, "chat_template", None) is not None:
+ logger.info(f"Using chat template: {config.chat_template}")
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
+ with open(fpath) as fd:
+ chat_template = fd.read()
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
+
+ # Set stop tokens for the tokenizer
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
+
+ # Add media tokens to the tokenizer
+ tokenizer.media_tokens = MEDIA_TOKENS
+ tokenizer.media_token_ids = {}
+ for name, token in MEDIA_TOKENS.items():
+ tokenizer.add_tokens([token], special_tokens=True)
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
+
+ # TODO(ligeng): is this necessary for llava?
+ config.hidden_size = llm.config.hidden_size
+ return llm, tokenizer
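
A hedged usage sketch of `build_llm_and_tokenizer` follows. It is not repository code: the parent config object is expected to carry `model_dtype` as a string that `eval()` can resolve (for example `"torch.bfloat16"`) and, optionally, a `chat_template` name matching one of the Jinja files below; the model name and the remaining keyword arguments here are placeholders.

```python
# Hypothetical call into the builder; values are placeholders.
from transformers import PretrainedConfig
from llava.model.language_model.builder import build_llm_and_tokenizer

cfg = PretrainedConfig()
cfg.model_dtype = "torch.bfloat16"   # evaluated with eval() inside the builder
cfg.chat_template = "qwen2"          # picks chat_templates/qwen2.jinja

llm, tokenizer = build_llm_and_tokenizer(
    "Qwen/Qwen2-7B-Instruct", cfg,
    attn_implementation="flash_attention_2",
    model_max_length=4096,
)
print(tokenizer.stop_tokens, cfg.hidden_size)   # hidden_size is copied from the LLM config
```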
diff --git a/llava/model/language_model/chat_templates/mistral.jinja b/llava/model/language_model/chat_templates/mistral.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..774073e68db8dc9ffc65daab8d948c0079235a5d
--- /dev/null
+++ b/llava/model/language_model/chat_templates/mistral.jinja
@@ -0,0 +1,11 @@
+{{ bos_token }}
+
+{% for message in messages if message['content'] is not none %}
+ {% if message['role'] == 'system' %}
+ {{ message['content'] | trim + '\n\n' }}
+ {% elif message['role'] == 'user' %}
+ {{ '[INST] ' + message['content'] | trim + ' [/INST]' }}
+ {% elif message['role'] == 'assistant' %}
+ {{ ' ' + message['content'] | trim + eos_token }}
+ {% endif %}
+{% endfor %}
diff --git a/llava/model/language_model/chat_templates/qwen2.jinja b/llava/model/language_model/chat_templates/qwen2.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..d3b0fd8c09e426c3a8c6e29edde3e56586c961bb
--- /dev/null
+++ b/llava/model/language_model/chat_templates/qwen2.jinja
@@ -0,0 +1,11 @@
+{% if messages[0]['role'] != 'system' %}
+ {{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
+{% endif %}
+
+{% for message in messages if message['content'] is not none %}
+ {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
+{% endfor %}
+
+{% if add_generation_prompt %}
+ {{ '<|im_start|>assistant\n' }}
+{% endif %}
diff --git a/llava/model/language_model/configuration_quantize.py b/llava/model/language_model/configuration_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b226923e1ab06e6f868455e72dd8430386c8ac8
--- /dev/null
+++ b/llava/model/language_model/configuration_quantize.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from dataclasses import dataclass
+
+from transformers import PretrainedConfig
+
+
+@dataclass
+class QuantizationConfig:
+ quantize_model: str = "false"
+ symm: bool = True
+ epsilon: float = 1e-10
+ fabit: str = "E4M3"
+ fwbit: str = "E4M3"
+ bobit: str = "E5M2"
+ row_blocksize: int = -1
+ col_blocksize: int = -1
+ qchoice: str = "none"
+ pad_to_multiple_of: int = 0
+
+ def __init__(
+ self,
+ quantize_model,
+ symm,
+ epsilon,
+ fabit,
+ fwbit,
+ bobit,
+ row_blocksize,
+ col_blocksize,
+ qchoice,
+ pad_to_multiple_of,
+ **kwargs,
+ ):
+ super().__init__()
+ self.quantize_model = quantize_model
+ self.symm = symm
+ self.epsilon = epsilon
+ self.fabit = fabit
+ self.fwbit = fwbit
+ self.bobit = bobit
+ self.row_blocksize = row_blocksize
+ self.col_blocksize = col_blocksize
+ self.qchoice = qchoice
+ self.pad_to_multiple_of = pad_to_multiple_of
+
+
+# class QuantizationConfig(PretrainedConfig):
+# def __init__(
+# self,
+# quantize_model="false",
+# symm=True,
+# epsilon=1e-10,
+# fabit="E4M3",
+# fwbit="E4M3",
+# bobit="E5M2",
+# row_blocksize=-1,
+# col_blocksize=-1,
+# qchoice="none",
+# pad_to_multiple_of=0,
+# **kwargs,
+# ):
+# super().__init__()
+# self.quantize_model = quantize_model
+# self.symm = symm
+# self.epsilon = epsilon
+# self.fabit = fabit
+# self.fwbit = fwbit
+# self.bobit = bobit
+# self.row_blocksize = row_blocksize
+# self.col_blocksize = col_blocksize
+# self.qchoice = qchoice
+# self.pad_to_multiple_of = pad_to_multiple_of
diff --git a/llava/model/language_model/fp8_qwen2_convert_from_hf.py b/llava/model/language_model/fp8_qwen2_convert_from_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a7f9f602491851c21cf7a7702a7174b414bfee
--- /dev/null
+++ b/llava/model/language_model/fp8_qwen2_convert_from_hf.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import os
+from dataclasses import asdict, dataclass, field
+from typing import Optional
+
+import torch
+import transformers
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+from .fp8activationqwen2 import FP8ActivationQwen2Config, make_state_dict_compatible
+
+
+@dataclass
+class ConvertArguments:
+ model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"})
+ save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"})
+ cache_dir: str = field(default=None, metadata={"help": "Directory to cache the model"})
+
+
+def download_and_convert_qwen2(convert_args: ConvertArguments, quantization_args: QuantizationConfig):
+ """
+ Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`,
+ and saves the converted model.
+
+ Args:
+ model_name (str): The model name or path to download the LLaMA model.
+ save_path (str): The path where the converted model weights will be saved.
+ cache_dir (Optional[str]): Directory to cache the model. Defaults to None.
+
+ Returns:
+ None
+ """
+ model_name = convert_args.model_name
+ save_path = convert_args.save_path
+ cache_dir = convert_args.cache_dir
+
+ # Step 1: Download the original LLaMA model
+ model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 2: Initialize the model configuration for FP8 or other custom config
+ config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 3: Apply make_state_dict_compatible to convert weights
+ compatible_state_dict = make_state_dict_compatible(model.state_dict())
+
+ # Step 4: Create a new model instance with compatible configuration
+ fp8_config = FP8ActivationQwen2Config(**config.to_dict())
+ fp8_config.coat_fp8_args = asdict(quantization_args)
+ fp8_config._name_or_path = save_path
+
+ converted_model = AutoModelForCausalLM.from_config(fp8_config, torch_dtype=torch.bfloat16)
+ converted_model.load_state_dict(compatible_state_dict)
+
+ # Step 5: Save the converted model and configuration using save_pretrained
+ os.makedirs(save_path, exist_ok=True)
+ converted_model.save_pretrained(save_path)
+ print(f"Converted model saved at {save_path}")
+
+
+if __name__ == "__main__":
+ # Parse command-line arguments
+ parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8
+ convert_args, quantization_args = parser.parse_args_into_dataclasses()
+
+ # Call the function with parsed arguments
+ download_and_convert_qwen2(convert_args, quantization_args)
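
Because the entry point uses `HfArgumentParser`, each dataclass field becomes a command-line flag. A hypothetical invocation is sketched below; `--model_name` and `--save_path` come from `ConvertArguments` above, while the FP8 flags are generated from `QuantizationConfig` in `_fp8_quantization_config.py` (defined elsewhere in the repository) and are therefore left as a placeholder.

```python
# Hypothetical invocation, shown as a comment; the FP8-specific flags are omitted
# because their names depend on QuantizationConfig, which is defined in another file.
#
#   python -m llava.model.language_model.fp8_qwen2_convert_from_hf \
#       --model_name Qwen/Qwen2-7B-Instruct \
#       --save_path ./qwen2-7b-fp8-converted \
#       [FP8 quantization flags...]
```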
diff --git a/llava/model/language_model/fp8activationqwen2.py b/llava/model/language_model/fp8activationqwen2.py
new file mode 100644
index 0000000000000000000000000000000000000000..34bde452da53042dc55b39ae204e541e1eb1f920
--- /dev/null
+++ b/llava/model/language_model/fp8activationqwen2.py
@@ -0,0 +1,1722 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 model."""
+
+import math
+import os
+from dataclasses import asdict, dataclass, field
+from fnmatch import fnmatch
+from functools import lru_cache
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.qwen2.modeling_qwen2 import (
+ Qwen2Attention,
+ Qwen2DecoderLayer,
+ Qwen2FlashAttention2,
+ Qwen2ForCausalLM,
+ Qwen2MLP,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ Qwen2RMSNorm,
+ Qwen2RotaryEmbedding,
+ Qwen2SdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+# FP8 related
+from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule
+from ..coat.activation.models._fp8manager import FP8Manager
+from ..coat.activation.real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+from ..liger.cross_entropy import LigerForCausalLMLoss
+from ..qlinear_te import QLinearTE
+
+# from .configuration_quantize import QuantizationConfig
+
+if is_flash_attn_2_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class FP8ActivationQwen2Config(Qwen2Config):
+ model_type = "fp8activation_qwen2"
+
+ def __init__(
+ self,
+ coat_fp8_args=None,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_hidden_layers,
+ num_attention_heads,
+ num_key_value_heads,
+ hidden_act,
+ max_position_embeddings,
+ initializer_range,
+ rms_norm_eps,
+ use_cache,
+ tie_word_embeddings,
+ rope_theta,
+ rope_scaling,
+ use_sliding_window,
+ sliding_window,
+ max_window_layers,
+ attention_dropout,
+ **kwargs,
+ )
+
+ self.coat_fp8_args = coat_fp8_args
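+ # coat_fp8_args carries the COAT quantization settings (FP8 formats, group size,
+ # weight_memory_efficient, ...); the decoder layers re-parse it via QuantizationConfig(**coat_fp8_args).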
+
+
+class FP8ActivationQwen2BeforeAttentionResidual(FP8CacheWeightModule):
+ """
+ This is the first part of a typical transformer attention block and contains (1) Residual (2) LayerNorm / RMSNorm (3) the Q / K / V projection Linear layers
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, qargs, layer_idx)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
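+ # NOTE (descriptive comment; naming assumed from COAT): fabit / fwbit / fobit and babit / bwbit / bobit are read here
+ # as the FP8 formats used for forward activations / weights / outputs and backward activations / weights / outputs
+ # respectively (values such as "E4M3" and "E5M2" are checked further below).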
+
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
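+ # On the first micro-batch, prepare_weight() computes and caches the FP8 weights and their scales;
+ # on later micro-batches only the cached scales are fetched and the weights are re-quantized on the
+ # fly inside the autograd function (see fp8_division there).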
+ # Prepare
+ with torch.no_grad():
+ if FP8Manager.is_first_microbatch:
+ # Directly use the corresponding weight
+ weight1, _, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, _, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, _, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ else:
+ weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
+
+ weight1, weight2, weight3 = None, None, None
+ return _FP8ActivationQwen2BeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ None,
+ weight1_s,
+ self.q_proj.bias,
+ self.k_proj.weight,
+ weight2,
+ None,
+ weight2_s,
+ self.k_proj.bias,
+ self.v_proj.weight,
+ weight3,
+ None,
+ weight3_s,
+ self.v_proj.bias,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ return _FP8ActivationQwen2BeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.q_proj.bias,
+ self.k_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.k_proj.bias,
+ self.v_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ self.v_proj.bias,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("This should be implemented in the future")
+ return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _FP8ActivationQwen2BeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight1_bias,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight2_bias,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ weight3_bias,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ assert (
+ weight1_t is None and weight2_t is None and weight3_t is None
+ ) # W^T should not be passed here when weight_memory_efficient is enabled
+ if weight1 is None:
+ assert weight2 is None and weight3 is None
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states
+ fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states
+ fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
+ else:
+ ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
+ ctx.bias = weight1_bias, weight2_bias, weight3_bias
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ return re_x, fc1_x, fc2_x, fc3_x
+
+ @staticmethod
+ def backward(ctx, fp_grad, query_g, key_g, value_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
+ weight1_bias, weight2_bias, weight3_bias = ctx.bias
+
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Gradient of the biases (TODO: make this better)
+ if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None:
+ att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0)
+ att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0)
+ att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0)
+ else:
+ att_q_bg = None
+ att_k_bg = None
+ att_v_bg = None
+
+ # Quantize the RoPE / FlashAttention output gradients. grad_input and grad_weight require different data layouts.
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
+ query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
+ key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
+ value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+
+ fc1_g1, att_q_wg = fp8_linear_backward(
+ ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
+ )
+ fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
+ fc1_g3, att_v_wg = fp8_linear_backward(
+ ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
+ )
+
+ fc1_g = fc1_g1 + fc1_g2 + fc1_g3
+
+ # LayerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together, and prepare the input of the next layer.
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ # For autograd, the forward output's dtype must match the backward tensor's dtype; `view` only reinterprets the bits and does not change the actual binary representation.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ # Although the next operator in the MLPResidual module is a linear layer, we return in_sg_g16 so that the gradient shapes stay compatible with the forward inputs; otherwise autograd would reject them.
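+ # The tuple below mirrors forward()'s argument list one-to-one; inputs that need no gradient
+ # (pre-quantized weights, scales, config objects, eps) receive None.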
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ att_q_wg,
+ None,
+ None,
+ None,
+ att_q_bg,
+ att_k_wg,
+ None,
+ None,
+ None,
+ att_k_bg,
+ att_v_wg,
+ None,
+ None,
+ None,
+ att_v_bg,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationQwen2AfterAttentionResidual(FP8CacheWeightModule):
+ """
+ This is the output part of a typical transformer attention block and contains (1) Residual (2) the output projection Linear layer
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(self, re_x, in_x):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ if FP8Manager.is_first_microbatch:
+ weight4, _, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+ else:
+ weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
+ weight4 = None
+
+ return _FP8ActivationQwen2AfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ None,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4, weight4_t, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationQwen2AfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ weight4_t,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ return re_x + self.o_proj(in_x), None, None
+
+
+class _FP8ActivationQwen2AfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs
+ ):
+ time_bench = os.getenv("TIME_BENCH")
+
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ # Attention Projection Linear Layer
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ if weight4 is None: # subsequent micro-batches: re-quantize from the cached scale
+ weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) #
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ # import IPython
+ # IPython.embed()
+ # Add the activations together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size)
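+ # The fused add returns both a full-precision sum (fp_x, kept for the residual path) and a
+ # per-group quantized copy (out_x, out_s) that is fed to the next FP8 module.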
+
+ if time_bench:
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if int(os.environ.get("LOCAL_RANK")) == 0:
+ print(
+ f"[AfterAt] Part 1: {start_1.elapsed_time(end_1):.6f} ms | "
+ f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | "
+ f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | "
+ f" Input shape: {re_x.shape}"
+ )
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ ctx.weight = weight4_origin, weight4_s
+ else:
+ ctx.weight = weight4_t, weight4_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight4_t, weight4_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # We do not save an extra flash_x to save the memory usage
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight4_t, weight4_s = fp8_division_transpose(
+ weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
+ )
+ fc4_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
+ )
+
+ return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class FP8ActivationQwen2MLPResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer MLP module that contains (1) Residual (2) LayerNorm / RMSNorm (3) two or three Linear layers
+ (4) GELU / SiLU activation
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.training = True
+
+ # below is only used when training = False
+ assert config.hidden_act == "silu", "We only support silu activation currently"
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ if FP8Manager.is_first_microbatch:
+ # Directly use the corresponding weight
+ weight1, _, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, _, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, _, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+ else:
+ weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ weight1, weight2, weight3 = None, None, None
+ return _FP8ActivationQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ None,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ None,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("Need TODO")
+ og_x = re_x
+ re_x = self.ff_norm(re_x)
+ re_x = self.ff_proj(re_x)
+ re_x = self.act(re_x)
+ re_x = self.ff_out(re_x)
+ re_x = og_x + re_x
+ return re_x, None, None
+
+
+class _FP8ActivationQwen2MLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear layers of the Gate Projection and Up Projection (computed as two separate GEMMs below).
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None # memory efficient
+ if weight1 is None:
+ assert weight2 is None and weight3 is None
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj
+ up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight3 is None: # memory efficient
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+ fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)
+
+ # Add the activation together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s)
+
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s)
+ if (
+ qargs.weight_memory_efficient
+ ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient.
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
+ else: # Weight1/2_t differs from the original weight, so saving it consumes an additional memory footprint.
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # Element-wise Multiplication gradient checkpointing
+ # silu gradient checkpointing
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
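+ # Recomputing silu_x and mul_x here from the saved gate/up activations (activation checkpointing)
+ # avoids storing them during the forward pass, trading a little compute for activation memory.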
+
+ if qargs.weight_memory_efficient:
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+ fc3_g, weight3_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
+ )
+
+ # [MEM TEST]
+ del out_g, out_g_t, weight3_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
+ silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Silu activation
+ silu_g, silu_gs, silu_g_t = fp8_silu_backward(
+ gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+
+ # Gate Proj
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
+ )
+
+ fc_g = fc1_g + fc2_g
+
+ # layerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ weight3_grad,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationQwen2AttentionWithoutLinear(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationQwen2FlashAttention2WithoutLinear(FP8ActivationQwen2AttentionWithoutLinear):
+ """
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stays untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate cache slicing only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the expected half-precision dtype to make sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationQwen2SdpaAttentionWithoutLinear(FP8ActivationQwen2AttentionWithoutLinear):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ return attn_output, None, past_key_value
+
+
+FP8LINEARQWEN2_ATTENTION_CLASSES = {
+ "eager": FP8ActivationQwen2AttentionWithoutLinear,
+ "flash_attention_2": FP8ActivationQwen2FlashAttention2WithoutLinear,
+ "sdpa": FP8ActivationQwen2SdpaAttentionWithoutLinear,
+}
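+# The decoder layer below picks one of these linear-free attention implementations based on
+# config._attn_implementation.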
+
+
+class FP8ActivationQwen2DecoderLayer(nn.Module):
+ def __init__(self, config: FP8ActivationQwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.BeforeAttention = FP8ActivationQwen2BeforeAttentionResidual(config, self.qargs, layer_idx)
+ self.AfterAttention = FP8ActivationQwen2AfterAttentionResidual(config, self.qargs, layer_idx)
+ self.MLPResidual = FP8ActivationQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size)
+
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ quant_hidden_states: torch.Tensor,
+ scale_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
+ time_bench = os.getenv("TIME_BENCH")
+
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
+ residual, query_states, key_states, value_states = self.BeforeAttention(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
+ )
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ # Self Attention without any linear layer
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ # Coat: The Output Projection Linear Layer and Residual
+ hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states)
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ # Residual Connection, LayerNorm, and the whole MLP module
+ hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
+ )
+
+ if time_bench:
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if int(os.environ.get("LOCAL_RANK")) == 0:
+ print(
+ f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | "
+ f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | "
+ f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | "
+ f" Input shape: {hidden_states.shape}"
+ )
+
+ outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),)
+
+ # if int(os.environ.get("LOCAL_RANK")) == 0:
+ # import IPython
+ # IPython.embed()
+ # else:
+ # import time
+ # time.sleep(1000)
+ # if output_attentions:
+ # outputs += (self_attn_weights,)
+
+ # if use_cache:
+ # outputs += (present_key_value,)
+
+ return outputs
+
+
+class FP8ActivationQwen2PreTrainedModel(Qwen2PreTrainedModel):
+ config_class = FP8ActivationQwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["FP8ActivationQwen2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class FP8ActivationQwen2Model(FP8ActivationQwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FP8ActivationQwen2DecoderLayer`]
+
+ Args:
+ config: Qwen2Config
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [FP8ActivationQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+
+ # Quantize
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
+ self.quantize_output_after_block = Coat_quantize_end(self.qargs)
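+ # Coat_quantize_bgn / Coat_quantize_end produce and consume the (hidden_states,
+ # quant_hidden_states, scale_hidden_states) triplet that is threaded through every decoder layer.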
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ self.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ # import warnings
+
+ # # Escalate all UserWarnings into exceptions (debugging aid)
+ # warnings.simplefilter("error", UserWarning)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ # Prepare the input for Coat decoderlayer
+ hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
+
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # if int(os.environ.get("LOCAL_RANK")) == 0:
+ # print(f"FP8 Layer Idx: {layer_idx}, Input Shape: {hidden_states.shape}, Memory Allocated: {torch.cuda.memory_allocated() // 1024 ** 2} MB | Peak: {torch.cuda.max_memory_allocated() // 1024 ** 2} MB")
+
+ # NOTE: We explicitly force the LLM not to use gradient checkpointing for these decoder layers
+ # exit(0)
+ # if self.gradient_checkpointing and self.training:
+ # layer_outputs = self._gradient_checkpointing_func(
+ # decoder_layer.__call__,
+ # hidden_states,
+ # quant_hidden_states,
+ # scale_hidden_states,
+ # causal_mask,
+ # position_ids,
+ # past_key_values,
+ # output_attentions,
+ # use_cache,
+ # cache_position,
+ # position_embeddings,
+ # )
+ # else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ quant_hidden_states,
+ scale_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # Summarize the output of the Decoder Layer
+ hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ _update_causal_mask = Qwen2Model._update_causal_mask
+
+
+class FP8ActivationQwen2ForCausalLM(FP8ActivationQwen2PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = FP8ActivationQwen2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @property
+ @lru_cache
+ def loss_function(self):
+ return LigerForCausalLMLoss
+
+ forward = Qwen2ForCausalLM.forward
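+ # The forward pass is reused verbatim from Qwen2ForCausalLM; only the loss_function property above is
+ # overridden, presumably so the parent forward picks up Liger's fused cross-entropy instead of the default loss.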
+
+
+AutoConfig.register("fp8activation_qwen2", FP8ActivationQwen2Config)
+AutoModel.register(FP8ActivationQwen2Config, FP8ActivationQwen2Model)
+AutoModelForCausalLM.register(FP8ActivationQwen2Config, FP8ActivationQwen2ForCausalLM)
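+# With these registrations, AutoModelForCausalLM.from_pretrained(...) resolves any checkpoint whose config
+# declares model_type "fp8activation_qwen2" to FP8ActivationQwen2ForCausalLM.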
+
+
+def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]):
+ compatible_state_dict = {}
+
+ for key, value in state_dict.items():
+ if fnmatch(key, "*self_attn.q_proj*"):
+ new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj")
+ elif fnmatch(key, "*self_attn.k_proj*"):
+ new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj")
+ elif fnmatch(key, "*self_attn.v_proj*"):
+ new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj")
+ elif fnmatch(key, "*self_attn.o_proj*"):
+ new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj")
+
+ elif fnmatch(key, "*mlp.gate_proj*"):
+ new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj")
+ elif fnmatch(key, "*mlp.up_proj*"):
+ new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj")
+ elif fnmatch(key, "*mlp.down_proj*"):
+ new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj")
+
+ else:
+ new_key = key
+
+ compatible_state_dict[new_key] = value
+
+ return compatible_state_dict
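+
+
+# A minimal usage sketch (an assumption for illustration, not part of the original training code):
+# remap a stock Qwen2 checkpoint into this module layout before loading, e.g.
+#
+#   base_sd = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-7B").state_dict()   # hypothetical checkpoint id
+#   fp8_model = FP8ActivationQwen2ForCausalLM(fp8_config)                      # fp8_config built beforehand
+#   missing, unexpected = fp8_model.load_state_dict(make_state_dict_compatible(base_sd), strict=False)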
diff --git a/llava/model/language_model/fp8activationresidualqwen2.py b/llava/model/language_model/fp8activationresidualqwen2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bccf84346e0e2b59ed6dc56ca1a6799bc11c5c5
--- /dev/null
+++ b/llava/model/language_model/fp8activationresidualqwen2.py
@@ -0,0 +1,1586 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 model."""
+
+import math
+import os
+from dataclasses import asdict, dataclass, field
+from fnmatch import fnmatch
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.qwen2.modeling_qwen2 import (
+ Qwen2Attention,
+ Qwen2DecoderLayer,
+ Qwen2FlashAttention2,
+ Qwen2ForCausalLM,
+ Qwen2MLP,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ Qwen2RMSNorm,
+ Qwen2RotaryEmbedding,
+ Qwen2SdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+# FP8 related
+from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule
+from ..coat.activation.models._fp8manager import FP8Manager
+from ..coat.activation.real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ # The three fused-add variants below are referenced later in this file; they are assumed to be
+ # exported by the same real_quantization package as the other kernels.
+ fp8_add_Ifp_Ifp_Og16,
+ fp8_add_Ig16_Ifp_Ofp_Og16,
+ fp8_add_Ig16_Ifp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+from ..qlinear_te import QLinearTE
+from .configuration_quantize import QuantizationConfig
+
+if is_flash_attn_2_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class FP8ActivationResidualQwen2Config(Qwen2Config):
+ model_type = "fp8activationresidual_qwen2"
+
+ def __init__(
+ self,
+ coat_fp8_args=None,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_hidden_layers,
+ num_attention_heads,
+ num_key_value_heads,
+ hidden_act,
+ max_position_embeddings,
+ initializer_range,
+ rms_norm_eps,
+ use_cache,
+ tie_word_embeddings,
+ rope_theta,
+ rope_scaling,
+ use_sliding_window,
+ sliding_window,
+ max_window_layers,
+ attention_dropout,
+ **kwargs,
+ )
+
+ self.coat_fp8_args = coat_fp8_args
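+ # `coat_fp8_args` is stored as-is and later unpacked via `QuantizationConfig(**coat_fp8_args)`
+ # by the decoder layers, so it is expected to be a dict of QuantizationConfig fields
+ # (group size, forward/backward bit formats, memory-efficiency flags, ...).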
+
+
+class FP8ActivationResidualQwen2BeforeAttentionResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer pre-attention module that contains (1) Residual (2) RMSNorm (3) the Q/K/V projection Linear layers
+ """
+
+ def __init__(
+ self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None
+ ):
+ super().__init__(config, qargs, layer_idx)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
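+ # The fwobits dict gathers the FP8 formats used by the kernels below; the fa/fw/fo entries
+ # appear to be the forward activation / weight / output formats and ba/bw/bo their backward
+ # counterparts (the autograd Functions in this file assume E4M3 forward and E5M2 backward).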
+
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
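+ # As in the reference Qwen2 attention, the Q/K/V projections carry biases while the output
+ # projection (owned by the AfterAttention module below) does not.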
+
+ def forward(self, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # Prepare
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
+ return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply(
+ x,
+ s,
+ self.q_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.q_proj.bias,
+ self.k_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.k_proj.bias,
+ self.v_proj.weight,
+ None,
+ None,
+ weight3_s,
+ self.v_proj.bias,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply(
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.k_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.v_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("This should be implemented in the future")
+ # Unreachable sketch of the intended eval path, kept for reference (att_proj / attn_norm
+ # are not attributes of this class):
+ # return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _FP8ActivationResidualQwen2BeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight1_bias,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight2_bias,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ weight3_bias,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states
+ fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states
+ fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
+ else:
+ ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
+ ctx.bias = weight1_bias, weight2_bias, weight3_bias
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
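+ # Return the (still FP8-quantized) residual stream together with the Q/K/V projection
+ # outputs; the attention computation itself happens outside this autograd Function, in the
+ # *AttentionWithoutLinear modules defined later in this file.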
+ return in_x, in_s, fc1_x, fc2_x, fc3_x
+
+ @staticmethod
+ def backward(ctx, q_grad, s_grad, query_g, key_g, value_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
+ weight1_bias, weight2_bias, weight3_bias = ctx.bias
+
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Gradient of Bias TODO: make this better
+ if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None:
+ att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0)
+ att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0)
+ att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0)
+ else:
+ att_q_bg = None
+ att_k_bg = None
+ att_v_bg = None
+
+ # Quantize the gradients coming back from RoPE / FlashAttention. grad_input and grad_weight require different data layouts.
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
+ query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
+ key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
+ value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+
+ fc1_g1, att_q_wg = fp8_linear_backward(
+ ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
+ )
+ fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
+ fc1_g3, att_v_wg = fp8_linear_backward(
+ ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
+ )
+
+ fc1_g = fc1_g1 + fc1_g2 + fc1_g3
+
+ # LayerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together, and prepare the input of the next layer.
+ in_g, in_sg, in_sg_g16 = fp8_add_Ig16_Ifp_Opt(
+ q_grad, s_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ # For autograd: the forward dtype should match the backward tensor's dtype. This does not change the actual binary representation.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ # Although the next operator is a linear layer in the MLPResidual module, we return in_sg_g16 so the gradient shapes match the forward outputs; otherwise autograd would reject them.
+ return (
+ in_g,
+ in_sg_g16,
+ att_q_wg,
+ None,
+ None,
+ None,
+ att_q_bg,
+ att_k_wg,
+ None,
+ None,
+ None,
+ att_k_bg,
+ att_v_wg,
+ None,
+ None,
+ None,
+ att_v_bg,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationResidualQwen2AfterAttentionResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer post-attention module that contains (1) Residual (2) the output projection Linear layer
+ """
+
+ def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(self, re_qx, re_sx, in_x):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
+
+ return _FP8ActivationResidualQwen2AfterAttentionResidual.apply(
+ re_qx,
+ re_sx,
+ in_x,
+ self.o_proj.weight,
+ None,
+ None,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4, weight4_t, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationResidualQwen2AfterAttentionResidual.apply(
+ re_qx,
+ re_sx,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ weight4_t,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # The original eval path (`return re_x + self.attn_out(in_x), None, None`) referenced
+ # attributes this class does not define; fail loudly instead.
+ raise NotImplementedError("Inference path is not implemented yet")
+
+
+class _FP8ActivationResidualQwen2AfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_qx,
+ re_sx,
+ flash_x,
+ weight4_origin,
+ weight4,
+ weight4_t,
+ weight4_s,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ ):
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+ # Attention Projection Linear Layer
+ if qargs.weight_memory_efficient:
+ assert weight4 is None # memory efficient
+ weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)
+ fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size)
+
+ # Add the quantized residual stream and the projected attention output together
+ fp_x, (out_x, out_s) = fp8_add_Ig16_Ifp_Ofp_Og16(re_qx, re_sx, fc4_x, flash_qx.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ ctx.weight = weight4_origin, weight4_s
+ else:
+ ctx.weight = weight4_t, weight4_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight4_t, weight4_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # We recompute the transposed flash_x here instead of saving it, to save memory
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight4_t, weight4_s = fp8_division_transpose(
+ weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
+ )
+ fc4_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
+ )
+
+ return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class FP8ActivationResidualQwen2MLPResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer MLP module that contains (1) Residual (2) RMSNorm (3) 2-3 Linear layers
+ (4) GELU / SiLU activation
+ """
+
+ def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.training = True
+
+ # below is only used when training = False
+ assert config.hidden_act == "silu", "We only support silu activation currently"
+ self.act_fn = ACT2FN[config.hidden_act]
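+ # `act_fn` is only needed by the (not yet implemented) eval path; the training path below
+ # uses the fused fp8_silu_forward / fp8_silu_backward kernels instead.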
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch)
+
+ return _FP8ActivationResidualQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.up_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.down_proj.weight,
+ None,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationResidualQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Eval-mode path is not implemented yet; the commented lines below sketch the intended
+ # computation but reference attributes (ff_norm / ff_proj / act / ff_out) that this class
+ # does not define.
+ raise NotImplementedError("Inference path is not implemented yet")
+ # og_x = re_x
+ # re_x = self.ff_norm(re_x)
+ # re_x = self.ff_proj(re_x)
+ # re_x = self.act(re_x)
+ # re_x = self.ff_out(re_x)
+ # re_x = og_x + re_x
+ # return re_x, None, None
+
+
+class _FP8ActivationResidualQwen2MLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Gate Projection and Up Projection linear layers (computed as two separate FP8 GEMMs here)
+ if qargs.weight_memory_efficient:
+ assert weight1 is None and weight2 is None and weight3 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj
+ up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight3 is None: # memory efficient
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+ fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)
+
+ # Add the activation together
+ out_x, out_s = fp8_add_Ifp_Ifp_Og16(re_x, fc3_x, mul_x.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
+
+ if (
+ qargs.weight_memory_efficient
+ ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient.
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
+ else: # Weight1/2_t differs from the original weight, so saving it would consume additional memory.
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return out_x, out_s
+
+ @staticmethod
+ def backward(ctx, out_g, out_gs):
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ if qargs.weight_memory_efficient:
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+ fc3_g, weight3_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
+ )
+
+ # [MEM TEST]
+ del out_g, out_g_t, weight3_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
+ silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Silu activation
+ silu_g, silu_gs, silu_g_t = fp8_silu_backward(
+ gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+
+ # Gate Proj
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
+ )
+
+ fc_g = fc1_g + fc2_g
+
+ # layerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ out_g, out_gs_max, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ weight3_grad,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationResidualQwen2AttentionWithoutLinear(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationResidualQwen2FlashAttention2WithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear):
+ """
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stays untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationResidualQwen2SdpaAttentionWithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ return attn_output, None, past_key_value
+
+
+FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES = {
+ "eager": FP8ActivationResidualQwen2AttentionWithoutLinear,
+ "flash_attention_2": FP8ActivationResidualQwen2FlashAttention2WithoutLinear,
+ "sdpa": FP8ActivationResidualQwen2SdpaAttentionWithoutLinear,
+}
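+# The dispatch table mirrors Hugging Face's `config._attn_implementation` selection, but every
+# variant here takes pre-projected query/key/value states: the linear projections live in the
+# BeforeAttention / AfterAttention FP8 modules instead.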
+
+
+class FP8ActivationResidualQwen2DecoderLayer(nn.Module):
+ def __init__(self, config: FP8ActivationResidualQwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.BeforeAttention = FP8ActivationResidualQwen2BeforeAttentionResidual(config, self.qargs, layer_idx)
+ self.AfterAttention = FP8ActivationResidualQwen2AfterAttentionResidual(config, self.qargs, layer_idx)
+ self.MLPResidual = FP8ActivationResidualQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size)
+
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ quant_hidden_states: torch.Tensor,
+ scale_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ quant_hidden_states (`torch.Tensor`): FP8-quantized input to the layer of shape `(batch, seq_len, embed_dim)`
+ scale_hidden_states (`torch.Tensor`): quantization scales associated with `quant_hidden_states`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
+ # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
+ residual_quant, residual_scale, query_states, key_states, value_states = self.BeforeAttention(
+ quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
+ )
+
+ # Self Attention without any linear layer
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ # Coat: The Output Projection Linear Layer and Residual
+ hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(
+ residual_quant, residual_scale, hidden_states
+ )
+
+ # Residual Connection, LayerNorm, and the whole MLP module
+ quant_hidden_states, scale_hidden_states = self.MLPResidual(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
+ )
+
+ outputs = ((quant_hidden_states, scale_hidden_states),)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class FP8ActivationResidualQwen2PreTrainedModel(Qwen2PreTrainedModel):
+ config_class = FP8ActivationResidualQwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["FP8ActivationResidualQwen2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class FP8ActivationResidualQwen2Model(FP8ActivationResidualQwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
+
+ Args:
+ config: Qwen2Config
+ """
+
+ def __init__(self, config: FP8ActivationResidualQwen2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [FP8ActivationResidualQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+
+ # Quantize
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
+ self.quantize_output_after_block = Coat_quantize_end(self.qargs)
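+ # Coat_quantize_bgn turns the embedding output into a (quantized tensor, scale) pair before
+ # the first decoder layer, and Coat_quantize_end converts the pair back to a regular tensor
+ # after the last layer; see the forward pass below.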
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ # Prepare the input for the Coat decoder layers
+ quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ quant_hidden_states,
+ scale_hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ quant_hidden_states,
+ scale_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ quant_hidden_states, scale_hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # Summarize the output of the Decoder Layer
+ hidden_states = self.quantize_output_after_block(quant_hidden_states, scale_hidden_states)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ _update_causal_mask = Qwen2Model._update_causal_mask
+
+
+class FP8ActivationResidualQwen2ForCausalLM(FP8ActivationResidualQwen2PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = FP8ActivationResidualQwen2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = Qwen2ForCausalLM.forward
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+ prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation
+
+
+AutoConfig.register("fp8activationresidual_qwen2", FP8ActivationResidualQwen2Config)
+AutoModel.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2Model)
+AutoModelForCausalLM.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2ForCausalLM)
+
+
+def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]):
+ compatible_state_dict = {}
+
+ for key, value in state_dict.items():
+ if fnmatch(key, "*self_attn.q_proj*"):
+ new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj")
+ elif fnmatch(key, "*self_attn.k_proj*"):
+ new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj")
+ elif fnmatch(key, "*self_attn.v_proj*"):
+ new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj")
+ elif fnmatch(key, "*self_attn.o_proj*"):
+ new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj")
+
+ elif fnmatch(key, "*mlp.gate_proj*"):
+ new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj")
+ elif fnmatch(key, "*mlp.up_proj*"):
+ new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj")
+ elif fnmatch(key, "*mlp.down_proj*"):
+ new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj")
+
+ else:
+ new_key = key
+
+ compatible_state_dict[new_key] = value
+
+ return compatible_state_dict
diff --git a/llava/model/language_model/fp8linearqwen2.py b/llava/model/language_model/fp8linearqwen2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ae93bb8d6798e9f327198abc405866e4a1527e2
--- /dev/null
+++ b/llava/model/language_model/fp8linearqwen2.py
@@ -0,0 +1,356 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 model."""
+
+import math
+from functools import lru_cache
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.qwen2.modeling_qwen2 import (
+ Qwen2Attention,
+ Qwen2DecoderLayer,
+ Qwen2FlashAttention2,
+ Qwen2ForCausalLM,
+ Qwen2MLP,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ Qwen2RMSNorm,
+ Qwen2RotaryEmbedding,
+ Qwen2SdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..liger.cross_entropy import LigerForCausalLMLoss
+from ..qlinear_te import QLinearTE
+from .configuration_quantize import QuantizationConfig
+
+if is_flash_attn_2_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class FP8LinearQwen2Config(Qwen2Config):
+ model_type = "fp8linear_qwen2"
+
+ def __init__(
+ self,
+ coat_fp8_args=None,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_hidden_layers,
+ num_attention_heads,
+ num_key_value_heads,
+ hidden_act,
+ max_position_embeddings,
+ initializer_range,
+ rms_norm_eps,
+ use_cache,
+ tie_word_embeddings,
+ rope_theta,
+ rope_scaling,
+ use_sliding_window,
+ sliding_window,
+ max_window_layers,
+ attention_dropout,
+ **kwargs,
+ )
+
+ self.coat_fp8_args = coat_fp8_args
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
+class FP8LinearQwen2MLP(Qwen2MLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ # self.gate_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.up_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.down_proj = te.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+ self.gate_proj = QLinearTE(
+ self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx
+ )
+ self.up_proj = QLinearTE(
+ self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx
+ )
+ self.down_proj = QLinearTE(
+ self.intermediate_size, self.hidden_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx
+ )
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class FP8LinearQwen2Attention(Qwen2Attention):
+ """
+ Multi-headed attention from the 'Attention Is All You Need' paper, modified to use sliding window
+ attention following Longformer and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: FP8LinearQwen2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+
+ self.q_proj = QLinearTE(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=True,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+ self.k_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=True,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+ self.v_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=True,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+ self.o_proj = QLinearTE(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=False,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+
+ forward = Qwen2Attention.forward
+
+
+class FP8LinearQwen2FlashAttention2(FP8LinearQwen2Attention):
+ """
+ Qwen2 flash attention module, following the Qwen2 attention module. This module inherits from `Qwen2Attention`,
+ as the weights of the module stay untouched. The only required change is in the forward pass,
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any. Additionally, for sliding window attention, SWA is applied only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ forward = Qwen2FlashAttention2.forward
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
+class FP8LinearQwen2SdpaAttention(FP8LinearQwen2Attention):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to the
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ forward = Qwen2SdpaAttention.forward
+
+
+FP8LINEARQWEN2_ATTENTION_CLASSES = {
+ "eager": FP8LinearQwen2Attention,
+ "flash_attention_2": FP8LinearQwen2FlashAttention2,
+ "sdpa": FP8LinearQwen2SdpaAttention,
+}
+
+
+class FP8LinearQwen2DecoderLayer(nn.Module):
+ def __init__(self, config: FP8LinearQwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.mlp = FP8LinearQwen2MLP(config, layer_idx)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ forward = Qwen2DecoderLayer.forward
+
+
+class FP8LinearQwen2PreTrainedModel(Qwen2PreTrainedModel):
+ config_class = FP8LinearQwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["FP8LinearQwen2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class FP8LinearQwen2Model(FP8LinearQwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
+
+ Args:
+ config: Qwen2Config
+ """
+
+ def __init__(self, config: FP8LinearQwen2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [FP8LinearQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ forward = Qwen2Model.forward
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ _update_causal_mask = Qwen2Model._update_causal_mask
+
+
+class FP8LinearQwen2ForCausalLM(FP8LinearQwen2PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = FP8LinearQwen2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
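+ # Overrides the default causal-LM loss with the Liger fused cross-entropy loss
+ # (see ..liger.cross_entropy); lru_cache keeps the lookup to a single construction.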
+ @property
+ @lru_cache
+ def loss_function(self):
+ return LigerForCausalLMLoss
+
+ forward = Qwen2ForCausalLM.forward
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+ prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation
+
+
+AutoConfig.register("fp8linear_qwen2", FP8LinearQwen2Config)
+AutoModel.register(FP8LinearQwen2Config, FP8LinearQwen2Model)
+AutoModelForCausalLM.register(FP8LinearQwen2Config, FP8LinearQwen2ForCausalLM)
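+# Minimal usage sketch (assumption: `my_fp8_args` stands in for the project's quantization
+# arguments, e.g. a QuantizationConfig instance; it is not defined in this file):
+#   config = FP8LinearQwen2Config(coat_fp8_args=my_fp8_args, num_hidden_layers=2)
+#   model = AutoModelForCausalLM.from_config(config)
+# Once registered above, the FP8 variants resolve through the standard Auto* factories.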
diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8fbe3f845abe162edd4d891cca35a752b0966da
--- /dev/null
+++ b/llava/model/language_model/llava_llama.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+import os
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from llava.model.loss import soft_cross_entropy
+from llava.model.utils.packing import set_seqlens_in_batch
+from llava.train.sequence_parallel.globals import get_pg_manager
+from llava.utils.logging import logger
+
+from ...train.utils import calculate_loss_weight
+from ..configuration_llava import LlavaConfig
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+
+
+class LlavaLlamaConfig(LlavaConfig):
+ model_type = "llava_llama"
+
+
+# FIXME we will follow the convention to add a new class for CausalLM in the future
+class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
+ config_class = LlavaLlamaConfig
+ main_input_name = "input_embeds"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+
+ def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.init_vlm(config=config, *args, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ *model_args,
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ ignore_mismatched_sizes: bool = False,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ use_safetensors: Optional[bool] = None,
+ **kwargs,
+ ):
+ if hasattr(cls, "load_pretrained"):
+ return cls.load_pretrained(
+ pretrained_model_name_or_path,
+ *model_args,
+ config=config,
+ cache_dir=cache_dir,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ use_safetensors=use_safetensors,
+ **kwargs,
+ )
+ return super(LlavaLlamaModel, cls).from_pretrained(
+ pretrained_model_name_or_path,
+ *model_args,
+ config=config,
+ cache_dir=cache_dir,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ use_safetensors=use_safetensors,
+ **kwargs,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
+ images: Optional[torch.FloatTensor] = None,
+ media_config: Optional[List] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ media_meta: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ packing: bool = True,
+ force_packing: bool = False,
+ seqlens_in_batch: Optional[torch.LongTensor] = None,
+ dpo_forward: bool = False,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ self.freezed_module_patch()
+
+ if images is not None:
+ if media is not None:
+ raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
+ logger.warning("The 'images' argument is deprecated. Please use 'media' instead.")
+ media = {"image": images}
+
+ if media_config is None:
+ media_config = defaultdict(dict)
+ if inputs_embeds is None:
+ inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask, media_meta)
+
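+ # Sequence packing: during training (and outside DPO forwards), per-sample lengths are taken
+ # from the attention mask and the multimodal batch is repacked into packed sequences before
+ # the LLM forward; see repack_multimodal_data.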
+ if force_packing or (packing and self.training and not dpo_forward):
+ if seqlens_in_batch is None:
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
+ set_seqlens_in_batch(seqlens_in_batch)
+
+ (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
+ inputs_embeds, attention_mask, position_ids, labels
+ )
+
+ outputs = self.llm(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ labels=labels,
+ **kwargs,
+ )
+
+ if self.training and getattr(self.config, "time_token_ids", []):
+ outputs.loss = soft_cross_entropy(
+ outputs.logits,
+ labels,
+ soft_tokens=self.config.time_token_ids,
+ std=self.config.soft_ce_std,
+ )
+
+ # Loss rescale for SP
+ if get_pg_manager() is not None:
+ loss_weight = calculate_loss_weight(labels)
+ outputs.loss = outputs.loss * loss_weight
+
+ if dpo_forward:
+ return outputs.logits, labels
+
+ return outputs
+
+
+AutoConfig.register("llava_llama", LlavaLlamaConfig)
+AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
diff --git a/llava/model/language_model/qllama.py b/llava/model/language_model/qllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..db823b94634ec5cfeb589319e947f64a473f3a30
--- /dev/null
+++ b/llava/model/language_model/qllama.py
@@ -0,0 +1,312 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+import os
+import time
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaFlashAttention2,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ LlamaSdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..qlinear_te import QLinearTE
+
+try:
+ import transformer_engine.pytorch as te
+except ImportError:
+ pass
+from ..qfunction import *
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "QLlamaConfig"
+
+
+class QLlamaConfig(LlamaConfig):
+ model_type = "qllama"
+
+
+class QLlamaMLP(LlamaMLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.gate_proj = QLinearTE(
+ self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx
+ )
+ self.up_proj = QLinearTE(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx)
+ self.down_proj = QLinearTE(
+ self.intermediate_size, self.hidden_size, bias=False, args=config, layer_idx=layer_idx
+ )
+
+
+class QLlamaAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: QLlamaConfig, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ # self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ # self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ # self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+
+ self.q_proj = QLinearTE(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+ self.k_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+ self.v_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+ self.o_proj = QLinearTE(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+
+
+class QLlamaFlashAttention2(QLlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ forward = LlamaFlashAttention2.forward
+
+
+class QLlamaSdpaAttention(QLlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to the
+ SDPA API.
+ """
+
+ forward = LlamaSdpaAttention.forward
+
+
+QLLAMA_ATTENTION_CLASSES = {
+ "eager": QLlamaAttention,
+ "flash_attention_2": QLlamaFlashAttention2,
+ "sdpa": QLlamaSdpaAttention,
+}
+
+
+class QLlamaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: QLlamaConfig, layer_idx):
+ super().__init__(config, layer_idx=layer_idx)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = QLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = QLlamaMLP(config, layer_idx)
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layer_idx = layer_idx
+
+ forward = LlamaDecoderLayer.forward
+
+
+class QLlamaPreTrainedModel(LlamaPreTrainedModel):
+ config_class = QLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["QLlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear) or isinstance(module, QLinearTE):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class QLlamaModel(QLlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: QLlamaConfig
+ """
+
+ def __init__(self, config: QLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [QLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ _update_causal_mask = LlamaModel._update_causal_mask
+ forward = LlamaModel.forward
+
+
+class QLlamaForCausalLM(QLlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = QLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.forward_step_id = 0
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = LlamaForCausalLM.forward
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+class QLlamaForSequenceClassification(QLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = QLlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ forward = LlamaForSequenceClassification.forward
+
+
+AutoConfig.register("qllama", QLlamaConfig)
+AutoModel.register(QLlamaConfig, QLlamaModel)
+AutoModelForCausalLM.register(QLlamaConfig, QLlamaForCausalLM)
diff --git a/llava/model/language_model/qllava_qllama.py b/llava/model/language_model/qllava_qllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fe4f0785352cc1d9fef4b20afdfb6183672fc32
--- /dev/null
+++ b/llava/model/language_model/qllava_qllama.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+import inspect
+import os
+import os.path as osp
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers import (
+ AutoConfig,
+ AutoModel,
+ GenerationConfig,
+ LlamaConfig,
+ LlamaForCausalLM,
+ PretrainedConfig,
+ PreTrainedModel,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import ContextManagers, no_init_weights
+
+from ..configuration_llava import LlavaConfig
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+from ..utils import get_model_config, get_model_config_fp8
+from .builder import build_llm_and_tokenizer
+from .llava_llama import LlavaLlamaConfig, LlavaLlamaModel
+
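+# Maps the `quantize_model` training argument to the quantized CausalLM class that
+# build_llm_and_tokenizer should instantiate (the classes are registered elsewhere in this package).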
+quantize_args_to_model_class = {
+ "fp8Linear_llama": "QLlamaForCausalLM",
+ "fp8LinearAndActivation_llama": "QMemLlamaForCausalLM",
+ "fp8Linear_qwen2": "FP8LinearQwen2ForCausalLM",
+ "fp8Activation_qwen2": "FP8ActivationQwen2ForCausalLM",
+ "fp8ActivationResidual_qwen2": "FP8ActivationResidualQwen2ForCausalLM",
+}
+
+
+class QLlavaLlamaConfig(LlavaLlamaConfig):
+ model_type = "qllava_qllama"
+
+
+# FIXME we will follow the convention to add a new class for CausalLM in the future
+class QLlavaLlamaModel(LlavaLlamaModel):
+ config_class = QLlavaLlamaConfig
+ main_input_name = "input_embeds"
+ supports_gradient_checkpointing = True
+
+ def __init__(self, config: QLlavaLlamaConfig = None, model_args=None, *args, **kwargs) -> None:
+ PreTrainedModel.__init__(self, config)
+ self.init_vlm(config=config, model_args=model_args, *args, **kwargs)
+
+ # rewrite to support QLlama
+ def init_vlm(self, config: PreTrainedModel = None, model_args=None, *args, **kwargs):
+ # TODO(ligeng): figure out how from_config and from_pretrained work in the HF implementation.
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"):
+ # already initialized, skip
+ return
+
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
+ if not hasattr(config, "model_dtype"):
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
+ config.model_dtype = model_dtype
+
+ if model_args.quantize_model in ["fp8Activation_qwen2", "fp8ActivationResidual_qwen2"]:
+ cfgs = get_model_config_fp8(config) # The first cfg is fp8
+ else:
+ cfgs = get_model_config(config)
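+ # get_model_config returns (llm, vision_tower, mm_projector); the FP8-activation variants
+ # additionally return an fp8 LLM config, which is forwarded to build_llm_and_tokenizer via kwargs.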
+ if len(cfgs) == 3:
+ llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
+ elif len(cfgs) == 4:
+ llm_cfg, vision_tower_cfg, mm_projector_cfg, fp8_llm_cfg = cfgs
+ kwargs.update({"fp8_llm_cfg": fp8_llm_cfg})
+ else:
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
+
+ kwargs.update(
+ {
+ "quantize_model_class": quantize_args_to_model_class[model_args.quantize_model],
+ "model_args": model_args,
+ }
+ )
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
+
+ for name, module in self.llm.named_modules():
+ module.layer_name = name
+
+ self.pad_to_multiple_of = model_args.pad_to_multiple_of
+
+ self.post_config()
+ self.is_loaded = True
+
+ assert (
+ self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
+ ), "At least one of the components must be instantiated."
+
+
+AutoConfig.register("qllava_qllama", QLlavaLlamaConfig)
+AutoModel.register(QLlavaLlamaConfig, QLlavaLlamaModel)
diff --git a/llava/model/language_model/qmemllama.py b/llava/model/language_model/qmemllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..26e794840df2230013a66f643745e9f50ffbeef9
--- /dev/null
+++ b/llava/model/language_model/qmemllama.py
@@ -0,0 +1,679 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+import os
+import time
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaFlashAttention2,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ LlamaSdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..qlinear_te import QLinearTE
+
+try:
+ import transformer_engine.pytorch as te
+except ImportError:
+ pass
+
+from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "QMemLlamaConfig"
+
+
+class QMemLlamaConfig(LlamaConfig):
+ model_type = "qmemllama"
+
+
+class QLlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.qargs = args
+ self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in")
+ self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out")
+
+ def forward(self, hidden_states, s):
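+ # The (hidden_states, s) pair carries the activation together with its quantization state
+ # from the preceding QAct layer (assumption based on usage): dequantize on entry,
+ # run RMSNorm in fp32, then re-quantize the normalized output.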
+ hidden_states = self.QAct_layernorm_in(hidden_states, s)
+
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = self.weight * hidden_states.to(input_dtype)
+
+ hidden_states, s = self.QAct_layernorm_out(hidden_states)
+ return hidden_states, s
+
+
+ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm)
+
+
+class QMemLlamaMLP(LlamaMLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ self.gate_proj = QLinear(
+ self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate"
+ )
+ self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up")
+ self.down_proj = QLinear(
+ self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down"
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum")
+ self.QAct_act_gate = QAct_FPin(config, layer_type="mlp_act_gate")
+
+ self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up")
+
+ self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in")
+ self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out")
+
+ self.QMul_act = QMul(config, layer_type="mul_act")
+
+ def forward(self, x, s):
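+ # Quantized SwiGLU: the shared input is split into gate/up branches, each branch runs through
+ # its QLinear with the scale `s` carried alongside, and the gated product feeds down_proj.
+ # Activations travel as (tensor, scale) pairs throughout.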
+ if self.config.pretraining_tp > 1:
+ raise ValueError("Currently Quantization is not implemented for tensor parallel for simplicity")
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ x = self.QAct_act_sum(x, s)
+ x_gate, s_gate = self.QAct_act_gate(x)
+ x_up, s_up = self.QAct_act_up(x)
+ x_gate, s_gate = self.gate_proj(x_gate, s_gate)
+ x_gate = self.QAct_act_in(x_gate, s_gate)
+ x_gate = self.act_fn(x_gate)
+ x_gate, s_gate = self.QAct_act_out(x_gate)
+
+ x_up, s_up = self.up_proj(x_up, s_up)
+ x, s = self.QMul_act(x_gate, x_up, s_gate, s_up)
+ down_proj, s = self.down_proj(x, s)
+ return down_proj, s
+
+
+class QMemLlamaAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ self.q_proj = QLinear(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_q",
+ )
+ self.k_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_k",
+ )
+ self.v_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_v",
+ )
+ self.o_proj = QLinear(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_proj",
+ )
+
+ self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum")
+
+ self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in")
+ self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in")
+ self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in")
+
+ self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out")
+ self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out")
+ self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out")
+ self.QAct_proj_in = QAct_FPin(config, layer_type="attn_proj_in")
+
+
+class QMemLlamaFlashAttention2(QMemLlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
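+ # QAct_qkv_sum consumes the (hidden_states, s) pair produced by the layernorm, and each of the
+ # q/k/v branches is re-quantized separately before its projection (assumption: the FPin/FPout
+ # suffixes denote the quantize/dequantize direction of these activation wrappers).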
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly.)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+
+ attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, s, attn_weights, past_key_value
+
+
+class QMemLlamaSdpaAttention(QMemLlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to the
+ SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ # attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ return attn_output, s, None, past_key_value
+
+
+QMemLLAMA_ATTENTION_CLASSES = {
+ "eager": QMemLlamaAttention,
+ "flash_attention_2": QMemLlamaFlashAttention2,
+ "sdpa": QMemLlamaSdpaAttention,
+}
+
+
+class QMemLlamaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config, layer_idx=layer_idx)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = QMemLlamaMLP(config, layer_idx)
+
+ self.input_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn"
+ )
+ self.post_attention_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp"
+ )
+
+ self.QAdd_attn = QAdd(config, layer_type="add_attn")
+ self.QAdd_mlp = QAdd(config, layer_type="add_mlp")
+
+ self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx")
+ self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re")
+
+ self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx")
+ self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual, res = self.QAct_reattnout_re(hidden_states)
+ hidden_states, s = self.QAct_reattnout_fx(hidden_states)
+
+ hidden_states, s = self.input_layernorm(hidden_states, s)
+
+ # Self Attention
+ hidden_states, s, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ s=s,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.QAdd_attn(residual, hidden_states, res, s)
+
+ # Fully Connected
+ residual, res = self.QAct_remlpout_re(hidden_states)
+ hidden_states, s = self.QAct_remlpout_fx(hidden_states)
+
+ hidden_states, s = self.post_attention_layernorm(hidden_states, s)
+ hidden_states, s = self.mlp(hidden_states, s)
+ hidden_states = self.QAdd_mlp(residual, hidden_states, res, s)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class QMemLlamaPreTrainedModel(LlamaPreTrainedModel):
+ config_class = QMemLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["QMemLlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear) or isinstance(module, QLinearTE):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class QMemLlamaModel(QMemLlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: QMemLlamaConfig
+ """
+
+ def __init__(self, config: QMemLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ _update_causal_mask = LlamaModel._update_causal_mask
+ forward = LlamaModel.forward
+
+
+class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = QMemLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.forward_step_id = 0
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = LlamaForCausalLM.forward
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = QMemLlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ forward = LlamaForSequenceClassification.forward
+
+
+AutoConfig.register("qmemllama", QMemLlamaConfig)
+AutoModel.register(QMemLlamaConfig, QMemLlamaModel)
+AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM)
diff --git a/llava/model/language_model/realqmemllama.py b/llava/model/language_model/realqmemllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eb71d015f8f7a0dcffc466b4093f67970537335
--- /dev/null
+++ b/llava/model/language_model/realqmemllama.py
@@ -0,0 +1,674 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+import os
+import time
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaFlashAttention2,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ LlamaSdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..qlinear_te import QLinearTE
+
+try:
+ import transformer_engine.pytorch as te
+except ImportError:
+ pass
+from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "QMemLlamaConfig"
+
+
+class QMemLlamaConfig(LlamaConfig):
+ model_type = "qmemllama"
+
+
+class QLlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.qargs = args
+ self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in")
+ self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out")
+
+ def forward(self, hidden_states, s):
+ hidden_states = self.QAct_layernorm_in(hidden_states, s)
+
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = self.weight * hidden_states.to(input_dtype)
+
+ hidden_states, s = self.QAct_layernorm_out(hidden_states)
+ return hidden_states, s
+
+
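+# Registering QLlamaRMSNorm in ALL_LAYERNORM_LAYERS lets HF utilities that special-case norm
+# layers (e.g. the Trainer's exclusion of norm weights from weight decay) treat it like the
+# built-in LlamaRMSNorm.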
+ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm)
+
+
+class QMemLlamaMLP(LlamaMLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+ self.gate_proj = QLinear(
+ self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate"
+ )
+ self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up")
+ self.down_proj = QLinear(
+ self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down"
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum")
+ self.QAct_act_gate = QAct_FPin(config, layer_type="mlp_act_gate")
+ self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up")
+
+ self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in")
+ self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out")
+
+ self.QMul_act = QMul(config, layer_type="mul_act")
+
+ def forward(self, x, s):
+ if self.config.pretraining_tp > 1:
+ raise ValueError("Quantization is currently not implemented for pretraining_tp > 1 (tensor parallelism).")
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ x = self.QAct_act_sum(x, s)
+ x_gate, s_gate = self.QAct_act_gate(x)
+ x_up, s_up = self.QAct_act_up(x)
+ x_gate, s_gate = self.gate_proj(x_gate, s_gate)
+ x_gate = self.QAct_act_in(x_gate, s_gate)
+ x_gate = self.act_fn(x_gate)
+ x_gate, s_gate = self.QAct_act_out(x_gate)
+
+ x_up, s_up = self.up_proj(x_up, s_up)
+ x, s = self.QMul_act(x_gate, x_up, s_gate, s_up)
+ down_proj, s = self.down_proj(x, s)
+ return down_proj, s
+
+
+class QMemLlamaAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+ self.q_proj = QLinear(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_q",
+ )
+ self.k_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_k",
+ )
+ self.v_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_v",
+ )
+ self.o_proj = QLinear(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_proj",
+ )
+
+ self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum")
+
+ self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in")
+ self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in")
+ self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in")
+
+ self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out")
+ self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out")
+ self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out")
+ self.QAct_proj_in = QAct_FPin(config, layer_type="attn_proj_in")
+
+
+class QMemLlamaFlashAttention2(QMemLlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view calls.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+
+ attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, s, attn_weights, past_key_value
+
+
+class QMemLlamaSdpaAttention(QMemLlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention` as the weights of the module stay untouched. The only changes are in the forward pass to adapt
+ to the SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ # attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ return attn_output, s, None, past_key_value
+
+
+QMemLLAMA_ATTENTION_CLASSES = {
+ "eager": QMemLlamaAttention,
+ "flash_attention_2": QMemLlamaFlashAttention2,
+ "sdpa": QMemLlamaSdpaAttention,
+}
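+# The decoder layer below picks its attention backend from this table via
+# config._attn_implementation, e.g. "sdpa" -> QMemLlamaSdpaAttention and
+# "flash_attention_2" -> QMemLlamaFlashAttention2.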
+
+
+class QMemLlamaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config, layer_idx=layer_idx)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = QMemLlamaMLP(config, layer_idx)
+ self.input_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn"
+ )
+ self.post_attention_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp"
+ )
+
+ self.QAdd_attn = QAdd(config, layer_type="add_attn")
+ self.QAdd_mlp = QAdd(config, layer_type="add_mlp")
+
+ self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx")
+ self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re")
+
+ self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx")
+ self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual, res = self.QAct_reattnout_re(hidden_states)
+ hidden_states, s = self.QAct_reattnout_fx(hidden_states)
+
+ hidden_states, s = self.input_layernorm(hidden_states, s)
+
+ # Self Attention
+ hidden_states, s, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ s=s,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.QAdd_attn(residual, hidden_states, res, s)
+
+ # Fully Connected
+ residual, res = self.QAct_remlpout_re(hidden_states)
+ hidden_states, s = self.QAct_remlpout_fx(hidden_states)
+
+ hidden_states, s = self.post_attention_layernorm(hidden_states, s)
+ hidden_states, s = self.mlp(hidden_states, s)
+ hidden_states = self.QAdd_mlp(residual, hidden_states, res, s)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class QMemLlamaPreTrainedModel(LlamaPreTrainedModel):
+ config_class = QMemLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["QMemLlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear) or isinstance(module, QLinearTE):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class QMemLlamaModel(QMemLlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: QMemLlamaConfig
+ """
+
+ def __init__(self, config: QMemLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ _update_causal_mask = LlamaModel._update_causal_mask
+ forward = LlamaModel.forward
+
+
+class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = QMemLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.forward_step_id = 0
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = LlamaForCausalLM.forward
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = QMemLlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ forward = LlamaForSequenceClassification.forward
+
+
+AutoConfig.register("qmemllama", QMemLlamaConfig)
+AutoModel.register(QMemLlamaConfig, QMemLlamaModel)
+AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM)
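+# Illustrative use of the registrations above (the checkpoint path is hypothetical): once this
+# module has been imported, the Auto classes resolve model_type "qmemllama", e.g.
+#   config = QMemLlamaConfig.from_pretrained("path/to/qmemllama-checkpoint")
+#   model = AutoModelForCausalLM.from_config(config)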
diff --git a/llava/model/liger/cross_entropy.py b/llava/model/liger/cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..06274f332b62ac2178904439a18c3474693242a2
--- /dev/null
+++ b/llava/model/liger/cross_entropy.py
@@ -0,0 +1,417 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import operator
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from .utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+ try:
+ # typical import path with dispatch available
+ from triton.language.extra.libdevice import tanh
+ except ModuleNotFoundError:
+ # for working with NGC containers
+ from triton.language.extra.cuda.libdevice import tanh
+else:
+ from triton.language.math import tanh
+
+_TRUE = tl.constexpr(1)
+_FALSE = tl.constexpr(0)
+
+
+@triton.jit
+def liger_cross_entropy_kernel(
+ X_ptr,
+ X_stride,
+ Y_ptr,
+ Y_stride,
+ loss_ptr,
+ z_loss_ptr,
+ loss_stride,
+ n_cols,
+ n_non_ignore,
+ ignore_index,
+ lse_square_scale: tl.constexpr,
+ label_smoothing: tl.constexpr,
+ reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
+ softcap,
+ RETURN_Z_LOSS: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ HAS_SOFTCAPPING: tl.constexpr,
+):
+ """
+ This kernel computes both cross entropy loss and the gradient of the input.
+ We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
+
+ Parameters:
+ X_ptr: Pointer to input tensor.
+ X_stride (int): The stride of the input tensor.
+ Y_ptr: Pointer to target tensor.
+ Y_stride (int): The stride of the target tensor.
+ loss_ptr: Pointer to tensor to store the loss.
+ z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
+ loss_stride (int): The stride of the loss tensor.
+ n_cols (int): The number of columns in the input tensor.
+ n_non_ignore (int): The number of non-ignored elements in the batch.
+ ignore_index (int): The index to ignore in the target.
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+ lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+ RETURN_Z_LOSS (int): Whether to store the z loss to z_loss_ptr. It must be 0 or 1.
+ reduction (str): The string for the reduction to apply
+ softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+ BLOCK_SIZE (int): The block size for Triton operations.
+ HAS_SOFTCAPPING (bool): Whether to apply soft-capping to the logits.
+ """
+
+ # https://github.com/triton-lang/triton/issues/1058
+ # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
+ program_id = tl.program_id(0).to(tl.int64)
+
+ # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
+ Y_ptr += program_id * Y_stride
+ y = tl.load(Y_ptr)
+
+ # 2. locate the start index
+ X_ptr += program_id * X_stride
+
+ if y == ignore_index:
+ # set all X_ptr as 0
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+ return
+
+ loss_ptr += program_id * loss_stride
+ z_loss_ptr += program_id * loss_stride
+
+ # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
+ # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
+
+ # 3. [Online softmax] first pass: find max + sum
+ m = float("-inf") # m is the max value. use the notation from the paper
+ d = 0.0 # d is the sum. use the notation from the paper
+ ori_X_y = tl.load(X_ptr + y) # we need to store the original value of X_y for the loss calculation
+ if HAS_SOFTCAPPING:
+ ori_X_y = softcap * tanh(ori_X_y / softcap)
+
+ # Label smoothing is a general case of normal cross entropy
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+ scaled_x_sum = 0.0
+ eps = label_smoothing / n_cols
+
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
+ if HAS_SOFTCAPPING:
+ X_block = softcap * tanh(X_block / softcap)
+ block_max = tl.max(X_block)
+ if label_smoothing > 0:
+ # scale X beforehand to avoid overflow
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+ m_new = tl.maximum(m, block_max)
+ d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
+ m = m_new
+
+ # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
+ # = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
+ # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
+ lse = m + tl.log(d)
+
+ # 4. [Online Softmax] Second pass: compute gradients
+ # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
+ # dx_y = (softmax(x_y) - 1) / N
+ # dx_i = softmax(x_i) / N, i != y
+ # For label smoothing:
+ # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
+ # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+ # = dx_i - (1 - label_smoothing) / N
+ # With Z loss:
+ # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+ # dx_y = dx_i - (1 - label_smoothing) / N
+ # For 'sum' reduction, no normalization is applied:
+ # dx_y = softmax(x_y) - 1
+ # dx_i = softmax(x_i), for i != y
+
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
+ if HAS_SOFTCAPPING:
+ intermediate = tanh(X_block / softcap)
+ X_block = softcap * intermediate
+ # softmax(x_i)
+ X_block = tl.exp(X_block - m) / d
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+ X_block += 2 * lse_square_scale * lse * X_block
+ # smoothing term
+ X_block += -eps
+ # special handle dx_y
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+ # reduction scale
+ if reduction == "mean":
+ X_block = X_block / (n_non_ignore)
+ # chain rule
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+ if HAS_SOFTCAPPING:
+ X_block = X_block * (1 - intermediate * intermediate)
+
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+
+ # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+ tl.debug_barrier()
+
+ # 5. Calculate the loss
+
+ # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
+ # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+ # = X_y - m - log d = X_y - lse
+ # sum(e ^ (X - max(X))) must be >= 1 because the max term is e ^ 0 = 1
+ # So we can safely calculate log (softmax(X_y)) without overflow
+ loss = lse - ori_X_y
+
+ # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+ # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+ # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+ # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+ # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
+ # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+ # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+ # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+ if label_smoothing > 0:
+ smooth_loss = scaled_x_sum + label_smoothing * lse
+ loss = loss * (1 - label_smoothing) + smooth_loss
+
+ # An auxiliary loss, z_loss
+ # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+ z_loss = lse_square_scale * lse * lse
+ loss += z_loss
+ # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+ if reduction == "mean":
+ z_loss = z_loss / n_non_ignore
+ loss = loss / n_non_ignore
+
+ tl.store(loss_ptr, loss)
+ if RETURN_Z_LOSS == _TRUE:
+ tl.store(z_loss_ptr, z_loss)
+
+
+# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
+# The optimal maximum block size depends on your hardware, your kernel, and your dtype
+MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
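+# Illustrative block-size arithmetic (vocab sizes assumed for the example): with
+# BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) as computed in cross_entropy_forward
+# below, a 32k vocab (V = 32000) gives BLOCK_SIZE = 32768 and each row is handled in a single loop
+# pass of the kernel, while a 128k vocab (V = 128256) is clamped to BLOCK_SIZE = 32768 and takes
+# ceil(128256 / 32768) = 4 passes per row.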
+
+
+_bool_to_return_z_loss = {
+ True: _TRUE.value,
+ False: _FALSE.value,
+}
+
+
+def cross_entropy_forward(
+ _input,
+ target,
+ ignore_index,
+ lse_square_scale,
+ label_smoothing,
+ reduction,
+ softcap,
+ return_z_loss,
+):
+ if not isinstance(return_z_loss, int):
+ assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
+ return_z_loss = _bool_to_return_z_loss[return_z_loss]
+ else:
+ assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
+
+ BT, V = _input.shape
+ n_rows = BT
+
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+ # unreduced loss
+ loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+ if return_z_loss == _TRUE.value:
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+ else:
+ z_loss_1d = loss_1d # dummy ptr when return_z_loss == False
+
+ n_non_ignore = (target != ignore_index).sum().item()
+
+ # ensure _input and target are contiguous in the last dimension
+ if _input.stride(-1) != 1:
+ _input = _input.contiguous()
+ if target.stride(-1) != 1:
+ target = target.contiguous()
+
+ # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
+ liger_cross_entropy_kernel[(n_rows,)](
+ X_ptr=_input,
+ X_stride=_input.stride(-2),
+ Y_ptr=target,
+ Y_stride=target.stride(-1), # always 1
+ loss_ptr=loss_1d,
+ z_loss_ptr=z_loss_1d,
+ loss_stride=loss_1d.stride(-1), # always 1
+ n_cols=V,
+ n_non_ignore=n_non_ignore,
+ ignore_index=ignore_index,
+ lse_square_scale=lse_square_scale,
+ label_smoothing=label_smoothing,
+ reduction=reduction,
+ softcap=softcap if softcap is not None else 0.0,
+ RETURN_Z_LOSS=return_z_loss,
+ BLOCK_SIZE=BLOCK_SIZE,
+ HAS_SOFTCAPPING=True if softcap is not None else False,
+ # TODO: 32 seems to give the best performance
+ # Performance is quite sensitive to num_warps
+ num_warps=32 if not is_hip() else 16,
+ )
+
+ loss = torch.sum(loss_1d)
+ if return_z_loss == _TRUE.value:
+ z_loss = torch.sum(z_loss_1d)
+ else:
+ z_loss = None
+
+ return loss, z_loss, _input
+
+
+def cross_entropy_backward(_input, grad_output):
+ # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+ pass
+
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+ else:
+ BT, V = _input.shape
+ n_rows = BT
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+ element_mul_kernel[(n_rows,)](
+ _input,
+ _input.stride(-2),
+ grad_output,
+ V,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=32 if not is_hip() else 16,
+ )
+
+ return _input
+
+
+class LigerCrossEntropyFunction(torch.autograd.Function):
+ """
+ This class implements a custom autograd function for the Liger Cross Entropy loss.
+ It overrides the forward and backward methods of the torch.autograd.Function class.
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ _input: torch.Tensor,
+ target: torch.Tensor,
+ ignore_index: int = -100,
+ lse_square_scale: float = 0.0,
+ label_smoothing: float = 0.0,
+ reduction: str = "mean",
+ softcap: Optional[float] = None,
+ return_z_loss: bool = False,
+ ):
+ """
+ The forward pass of the Liger Cross Entropy loss.
+
+ Parameters:
+ ctx : The context object.
+ _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
+ target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+ ignore_index (int): The index to ignore in the target.
+ lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+ reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
+ softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+
+ Returns:
+ tuple: A tuple (loss, z_loss) of the computed losses. The elements are tensors or None.
+ """
+ loss, z_loss, _input = cross_entropy_forward(
+ _input,
+ target,
+ ignore_index,
+ lse_square_scale,
+ label_smoothing,
+ reduction,
+ softcap,
+ return_z_loss,
+ )
+ # TODO: investigation
+ # If we don't detach the _input tensor, the memory will double
+ # Not sure why but seems that there will be a time both grad and value exist but in different location
+ ctx.save_for_backward(_input.detach())
+ ctx.return_z_loss = return_z_loss
+
+ return loss, z_loss
+
+ @staticmethod
+ def backward(ctx, grad_output, grad_output2):
+ """
+ The backward pass of the Liger Cross Entropy loss.
+
+ Parameters:
+ ctx : The context object with saved tensors.
+ grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
+ grad_output2 (tensor): Unused.
+ Returns:
+ tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
+ """
+ if ctx.return_z_loss:
+ del grad_output2 # z_loss is only for logging
+
+ (_input,) = ctx.saved_tensors
+ _input = cross_entropy_backward(_input, grad_output)
+ return (
+ _input,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def liger_fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
+ reduction = "sum" if num_items_in_batch is not None else "mean"
+ # loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
+ loss, _ = LigerCrossEntropyFunction.apply(source, target, ignore_index, 0.0, 0.0, reduction)
+ if reduction == "sum":
+ loss = loss / num_items_in_batch
+ return loss
+
+
+def LigerForCausalLMLoss(
+ logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
+):
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ # Flatten the tokens
+ shift_logits = shift_logits.view(-1, vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = liger_fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
+ return loss
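+# Illustrative call (shapes assumed for the example): given lm-head logits of shape (B, T, V)
+# and labels of shape (B, T) with -100 marking ignored positions,
+#   loss = LigerForCausalLMLoss(logits, labels, vocab_size=logits.size(-1))
+# shifts the sequence by one token, flattens it, and returns a scalar loss averaged over the
+# non-ignored tokens; passing num_items_in_batch switches the kernel to "sum" reduction and
+# divides the summed loss by that count instead.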
diff --git a/llava/model/liger/utils.py b/llava/model/liger/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5290b98c237f43757d0ed74d91f8f3bf6a807eeb
--- /dev/null
+++ b/llava/model/liger/utils.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+The following line
+https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
+is based on code from Unsloth, located at:
+https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+Modifications made by Yanning Chen, 2024.
+"""
+
+import functools
+import importlib
+import operator
+from typing import Callable
+
+import torch
+import triton
+import triton.language as tl
+from packaging.version import Version
+
+
+def is_hip() -> bool:
+ return torch.version.hip is not None
+
+
+def ensure_contiguous(fn):
+ @functools.wraps(fn)
+ def wrapper(ctx, *args, **kwargs):
+ def maybe_to_contiguous(x):
+ return x.contiguous() if isinstance(x, torch.Tensor) else x
+
+ args = [maybe_to_contiguous(arg) for arg in args]
+ kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
+ return fn(ctx, *args, **kwargs)
+
+ return wrapper
+
+
+def calculate_settings(n):
+ # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+ MAX_FUSED_SIZE = 65536
+ BLOCK_SIZE = triton.next_power_of_2(n)
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
+ raise RuntimeError(
+ f"Cannot launch Triton kernel since n = {n} exceeds "
+ f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
+ )
+
+ num_warps = 4
+ if BLOCK_SIZE >= 32768:
+ num_warps = 32 if not is_hip() else 16
+ elif BLOCK_SIZE >= 8192:
+ num_warps = 16
+ elif BLOCK_SIZE >= 2048:
+ num_warps = 8
+ return BLOCK_SIZE, num_warps
+
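+# Illustrative outputs (inputs assumed for the example): calculate_settings(4096) returns
+# (4096, 8), and calculate_settings(50257) returns (65536, 32) on CUDA (16 warps on HIP),
+# since triton.next_power_of_2(50257) = 65536.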
+
+def compare_version(package: str, operator: Callable, target: str):
+ try:
+ pkg = importlib.import_module(package)
+ except ImportError:
+ return False
+ pkg_version = Version(pkg.__version__)
+ return operator(pkg_version, Version(target))
+
+
+def get_amp_custom_fwd_bwd() -> Callable:
+ if compare_version("torch", operator.ge, "2.4.0"):
+ return (
+ functools.partial(torch.amp.custom_fwd, device_type="cuda"),
+ functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+ )
+ return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
+
+
+amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
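+# Sketch of the intended decorator usage (MyOp is a hypothetical example): these wrappers keep
+# autograd.Function methods consistent under torch.autocast on CUDA, e.g.
+#   class MyOp(torch.autograd.Function):
+#       @staticmethod
+#       @amp_custom_fwd
+#       def forward(ctx, x): ...
+#
+#       @staticmethod
+#       @amp_custom_bwd
+#       def backward(ctx, grad): ...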
+
+
+torch_to_triton_dtype = {
+ torch.float32: tl.float32,
+ torch.float16: tl.float16,
+ torch.bfloat16: tl.bfloat16,
+}
+
+
+@triton.jit
+def element_mul_kernel(
+ X_ptr,
+ X_stride,
+ grad_output_ptr,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ """
+ This function multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr.
+ The multiplication is performed in-place on the tensor pointed to by X_ptr.
+
+ Parameters:
+ X_ptr: Pointer to the input tensor.
+ X_stride (int): The stride of the input tensor.
+ grad_output_ptr: Pointer to the gradient output value.
+ n_cols (int): The number of columns in the input tensor.
+ BLOCK_SIZE (int): The block size for Triton operations.
+ """
+
+ # Get the program ID and convert it to int64 to avoid overflow
+ program_id = tl.program_id(0).to(tl.int64)
+
+ # Locate the start index
+ X_ptr += program_id * X_stride
+
+ # Load the gradient output value
+ grad_output = tl.load(grad_output_ptr)
+
+ # Perform the element-wise multiplication
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
+ tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
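+# Typical launch, mirroring cross_entropy_backward in cross_entropy.py (values are illustrative):
+#   element_mul_kernel[(n_rows,)](X, X.stride(-2), grad_output, n_cols, BLOCK_SIZE=BLOCK_SIZE,
+#                                 num_warps=32 if not is_hip() else 16)
+# i.e. one program per row, each scaling its stored gradients in place by the grad_output scalar.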
diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6feb1014fbea1a6d7904b88c5d2f2a0d84607bb
--- /dev/null
+++ b/llava/model/llava_arch.py
@@ -0,0 +1,861 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import json
+import logging
+import os
+import os.path as osp
+import warnings
+from abc import ABC
+from collections import OrderedDict, defaultdict, deque
+from itertools import chain
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from einops import rearrange
+from hydra.utils import instantiate
+from transformers import AutoConfig, GenerationConfig, LogitsProcessor, PreTrainedModel
+from transformers.modeling_utils import ContextManagers, no_init_weights
+
+from llava.constants import DEFAULT_SOUND_TOKEN, DEFAULT_SPEECH_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
+from llava.mm_utils import process_image, process_images, process_sounds, process_sound_masks
+from llava.model.configuration_llava import LlavaConfig, ResponseFormat
+from llava.model.language_model.builder import build_llm_and_tokenizer
+from llava.model.multimodal_encoder.builder import build_sound_tower
+from llava.model.multimodal_projector.builder import build_speech_mm_projector, build_sound_mm_projector
+from llava.model.utils import get_model_config
+from llava.train.sequence_parallel import get_pg_manager
+from llava.utils import distributed
+from llava.utils.media import extract_media
+from llava.utils.tokenizer import tokenize_conversation
+
+
+class LlavaMetaModel(ABC):
+ def _init_llm(self, llm_cfg, config, *args, **kwargs):
+ llm, tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
+ return llm, tokenizer
+
+ def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs):
+ # TODO(ligeng): figure out how from_config and from_pretrained work in the HF implementation.
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "speech_tower") or hasattr(self, "sound_tower") or hasattr(self, "mm_projector") or hasattr(self, "speech_mm_projector") or hasattr(self, "sound_mm_projector"):
+ # already initialized, skipped
+ return
+
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
+ if not hasattr(config, "model_dtype"):
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
+ config.model_dtype = model_dtype
+
+ cfgs = get_model_config(config)
+ print(cfgs)
+ if len(cfgs) == 7:
+ llm_cfg, vision_tower_cfg, speech_tower_cfg, sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg, sound_mm_projector_cfg = cfgs
+ else:
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config.")
+
+ self.llm, self.tokenizer = self._init_llm(llm_cfg, config, *args, **kwargs)
+
+ self.sound_tower = build_sound_tower(sound_tower_cfg, config)
+ self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
+
+ if isinstance(self.config, dict):
+ self.vocab_size = config.llm_cfg["vocab_size"] + NUM_EXTRA_TOKENS
+ else:
+ self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
+ logging.info(
+ f"[XGrammar] config is not a dict, loading vocab size from tokenizer {self.tokenizer.vocab_size} + {NUM_EXTRA_TOKENS} => {self.vocab_size}"
+ )
+
+ # XGrammar tokenizer and grammar compiler
+ # lazy init only when specified json output during inference
+ self.grammar_compiler = None
+
+ self.encoders = {}
+ for name in ["sound"]:
+ config = getattr(self.config, f"{name}_encoder")
+ if isinstance(config, str):
+ config = json.loads(config)
+ self.encoders[name] = instantiate(config, parent=self)
+
+ self.post_config()
+ self.is_loaded = True
+
+ assert (
+ self.llm is not None or self.vision_tower is not None or self.speech_tower is not None or self.mm_projector is not None or self.speech_mm_projector is not None
+ ), "At least one of the components must be instantiated."
+
+ @classmethod
+ def load_from_config(cls, model_path_or_config, *args, **kwargs):
+ pass
+
+ ## FIXME we will use this function to load the model in the future
+ @classmethod
+ def load_pretrained(cls, model_path_or_config, *args, **kwargs):
+ kwargs.pop("config", None)
+
+ if isinstance(model_path_or_config, str):
+ config = AutoConfig.from_pretrained(model_path_or_config)
+ elif isinstance(model_path_or_config, LlavaConfig):
+ config = model_path_or_config
+ else:
+ raise NotImplementedError(
+ f"wrong type, {type(model_path_or_config)} \
+ {isinstance(model_path_or_config, LlavaConfig)}"
+ )
+
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
+ if not hasattr(config, "model_dtype"):
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
+ config.model_dtype = model_dtype
+
+ cfgs = get_model_config(config)
+ if len(cfgs) == 7:
+ llm_cfg, vision_tower_cfg, speech_tower_cfg, sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg, sound_mm_projector_cfg = cfgs
+ else:
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config.")
+
+ init_context = [
+ no_init_weights(_enable=True),
+ ]
+
+ with ContextManagers(init_context):
+ vlm = cls(config, *args, **kwargs)
+
+ if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "speech_tower") or hasattr(vlm, "sound_tower") or hasattr(vlm, "mm_projector") or hasattr(vlm, "speech_mm_projector") or hasattr(vlm, "sound_mm_projector"):
+ if vlm.is_loaded:
+ return vlm
+
+ vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
+ vlm.sound_tower = build_sound_tower(sound_tower_cfg, config)
+ vlm.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
+
+ vlm.post_config()
+ vlm.is_loaded = True
+
+ # FIXME(ligeng, yunhao): llm should never be none here.
+ assert (
+ vlm.llm is not None or vlm.vision_tower is not None or vlm.speech_tower is not None or vlm.mm_projector is not None or vlm.speech_mm_projector is not None
+ ), "At least one of the components must be instantiated."
+ return vlm
+
+ ## FIXME we will use this function to save the model in the future
+ def save_pretrained(self, output_dir, state_dict=None):
+ if state_dict is None:
+ # otherwise fetch from deepspeed
+ # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
+ state_dict = self.state_dict()
+
+ if getattr(self, "tokenizer", None):
+ self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
+
+ if self.get_llm():
+ print(f"saving llm to {osp.join(output_dir, 'llm')}")
+ self.llm.config._name_or_path = osp.join(output_dir, "llm")
+ llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
+ self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
+ self.config.llm_cfg = self.llm.config
+
+
+ if self.get_sound_tower():
+ print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}")
+ self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower")
+ sound_tower_state_dict = OrderedDict(
+ {k.split("sound_tower.sound_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k}
+ )
+ self.sound_tower.sound_tower.save_pretrained(
+ os.path.join(output_dir, "sound_tower"),
+ state_dict=sound_tower_state_dict,
+ )
+ self.config.sound_tower_cfg = self.sound_tower.config
+
+ if self.get_sound_mm_projector():
+ print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}")
+ self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector")
+ sound_mm_projector_state_dict = OrderedDict(
+ {k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k}
+ )
+ self.sound_mm_projector.save_pretrained(
+ os.path.join(output_dir, "sound_mm_projector"),
+ state_dict=sound_mm_projector_state_dict,
+ )
+ self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
+
+ ## update and save top-level config
+ self.config._name_or_path = output_dir
+ self.config.architectures = [self.__class__.__name__]
+ self.config.save_pretrained(output_dir)
+
+ def get_llm(self):
+ llm = getattr(self, "llm", None)
+ if type(llm) is list:
+ llm = llm[0]
+ return llm
+
+ def get_lm_head(self):
+ lm_head = getattr(self.get_llm(), "lm_head", None)
+ return lm_head
+
+ def get_sound_tower(self):
+ sound_tower = getattr(self, "sound_tower", None)
+ if type(sound_tower) is list:
+ sound_tower = sound_tower[0]
+ return sound_tower
+
+
+ def get_sound_mm_projector(self):
+ sound_mm_projector = getattr(self, "sound_mm_projector", None)
+ if type(sound_mm_projector) is list:
+ sound_mm_projector = sound_mm_projector[0]
+ return sound_mm_projector
+
+ def post_config(self):
+ self.training = self.get_llm().training
+ ## configuration
+ if getattr(self.config, "llm_cfg", None) is None:
+ self.config.llm_cfg = self.llm.config
+ self.config.speech_tower_cfg = self.speech_tower.config
+ if getattr(self.config, "sound_tower_cfg", None) is None:
+ self.config.sound_tower_cfg = self.sound_tower.config
+ if getattr(self.config, "sound_mm_projector_cfg", None) is None:
+ self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
+
+ def freezed_module_patch(self):
+ """
+ Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the frozen modules.
+ """
+ if self.training:
+ if self.get_llm() and not getattr(self.config, "tune_language_model", False):
+ pass
+
+ if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False):
+ self.get_sound_tower().eval()
+ if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False):
+ self.get_sound_mm_projector().eval()
+
+ def encode_sound(self, sounds, masks=None):
+
+ sound_features = self.get_sound_tower()(sounds, masks)
+ sound_features = self.get_sound_mm_projector()(sound_features)
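+        # Shape sketch (hedged; dimension names are illustrative, assuming a Whisper-style sound tower
+        # and a projector that maps audio features into the LLM embedding space):
+        #   sounds                         -> [num_sounds, n_mels, n_frames]
+        #   after get_sound_tower()        -> [num_sounds, T_audio, d_audio]
+        #   after get_sound_mm_projector() -> [num_sounds, T_audio, d_llm]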
+ return sound_features
+
+    ## @yunhao: is there a better way to handle function calls and attributes for the llm?
+ ## support beam search
+ def _temporary_reorder_cache(self, past_key_values, sorted_idx):
+ return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)
+
+ def get_input_embeddings(self):
+ return self.get_llm().get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.get_llm().get_output_embeddings()
+
+ def resize_token_embeddings(self, embed_size):
+ self.get_llm().resize_token_embeddings(embed_size)
+
+
+class LlavaMetaForCausalLM(ABC):
+ def _embed(
+ self,
+ input_ids: torch.Tensor,
+ media: Dict[str, List[torch.Tensor]],
+ media_config: Dict[str, Dict[str, Any]],
+ labels: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ media_meta: Dict[str, Dict[str, Any]]= None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ for name in media:
+ self.encoders[name].end_tokens = None
+
+ # Extract text and media embeddings
+ text_embeds = self.llm.model.embed_tokens(input_ids)
+ media_embeds = self.__embed_media_tokens(media, media_config, media_meta)
+ # This is a workaround to make sure the dummy embeddings are consumed
+ while media_embeds.get("dummy"):
+ dummy_embed = media_embeds["dummy"].popleft()
+ text_embeds += torch.sum(dummy_embed) * 0
+        batch_size = labels.shape[0]
+
+ # Build inverse mapping from token ID to media name
+ media_tokens = {}
+ for name, token_id in self.tokenizer.media_token_ids.items():
+ media_tokens[token_id] = name
+
+        # Number of valid audio embedding positions per sound, rounded to the nearest multiple of 10.
+        num_audio_tokens = torch.stack(media_meta["sound_embed_masks"], dim=0).sum(-1)
+        num_audio_tokens = torch.tensor([round(int(x) / 10) * 10 for x in num_audio_tokens])
+        num_audios = len(media_embeds['sound'])  # the queue length is the total number of audio clips
+        max_audio_tokens, embed_dim = media_embeds['sound'][0].shape
+
+ audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
+ num_audio_tokens.device
+ ) < num_audio_tokens.unsqueeze(1)
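+        # Illustrative example (hedged numbers): num_audio_tokens = [30, 20] with max_audio_tokens = 40
+        # yields a [2, 40] boolean mask whose rows have 30 and 20 leading True entries, respectively.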
+
+ audio_embeds = []
+ while media_embeds['sound']:
+ audio_embeds.append(media_embeds['sound'].popleft())
+ audio_embeds = torch.stack(audio_embeds,dim=0)
+
+ masked_audio_features = audio_embeds[audio_features_mask].view(-1, embed_dim)
+ batch_size, sequence_length = input_ids.shape
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
+
+ left_padding = True
+ if batch_size > 1:
+ if _left_padding and not _right_padding:
+ left_padding = True
+ elif not _left_padding and _right_padding:
+ left_padding = False
+ elif not _left_padding and not _right_padding:
+                # neither side is padded, so we cannot tell; fall back to the tokenizer's padding side
+ left_padding = self.tokenizer.padding_side == "left"
+ else:
+ # invalid attention_mask
+                raise ValueError(f"Both sides of attention_mask contain zeros, which is invalid: {attention_mask}")
+
+ # 1. Create a mask to know where special audio tokens are
+        special_audio_token_mask = input_ids == self.tokenizer.media_token_ids['sound']  # hard-coded to work only with 'sound' media
+ num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)
+
+ # In case the Audio model or the Language model has been offloaded to CPU, we need to manually
+ # set the corresponding tensors into their correct target device.
+ target_device = text_embeds.device
+ attention_mask = attention_mask.to(target_device)
+ input_ids = input_ids.to(target_device)
+ num_audio_tokens = num_audio_tokens.to(target_device)
+ batch_indices, non_audio_indices = torch.where(
+ (input_ids != self.tokenizer.media_token_ids['sound']) & (attention_mask == 1)
+ )
+
+ # 2. Compute the positions where text should be written
+        # Calculate new positions for text tokens in the merged audio-text sequence.
+        # `special_audio_token_mask` identifies the audio placeholder tokens; each one is expanded into
+        # `num_audio_tokens` positions. `torch.cumsum` computes how much each audio token shifts the
+        # positions of subsequent text tokens.
+ token_placeholder_num = torch.zeros_like(input_ids)
+ token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1
+ token_placeholder_num = token_placeholder_num + 1
+ new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
+ max_token_num = token_placeholder_num.sum(-1).max()
+ nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_audio_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
+ batch_indices, non_audio_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_audio_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
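+        # Worked example (hedged, illustrative numbers): for input_ids = [A, <sound>, B] with
+        # num_audio_tokens = [4], token_placeholder_num is [1, 4, 1] and new_token_positions is [0, 4, 5]:
+        # A lands at slot 0, B at slot 5, and the audio features fill slots 1..4.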
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size, max_token_num, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device
+ )
+ final_attention_mask = torch.zeros(
+ batch_size, max_token_num, dtype=attention_mask.dtype, device=text_embeds.device
+ )
+ final_input_ids = torch.full(
+ (batch_size, max_token_num), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=text_embeds.device
+ )
+
+        # 4. Fill the embeddings based on the mask. If we have ["hey", "<sound>", "how", "are"],
+        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the audio features
+ final_embedding[batch_indices, text_to_overwrite] = text_embeds[batch_indices, non_audio_indices]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
+ final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
+ final_labels = None
+ if labels is not None:
+ labels = labels.to(target_device)
+ final_labels = torch.full_like(final_attention_mask, IGNORE_INDEX, dtype=torch.long) #.to(torch.long)
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_audio_indices]
+
+ # 5. Fill the embeddings corresponding to the audios. Anything that is still zeros needs filling
+ audio_to_overwrite = torch.full(
+ (batch_size, max_token_num), True, dtype=torch.bool, device=text_embeds.device
+ )
+ audio_to_overwrite[batch_indices, text_to_overwrite] = False
+ seq_indices = torch.arange(max_token_num).unsqueeze(0).to(target_device)
+ seq_indices = seq_indices.expand(batch_size, max_token_num)
+
+ if left_padding:
+ # exclude padding on the left
+ max_token_num = max_token_num.to(target_device)
+ val = (max_token_num - seq_indices) <= (
+ token_placeholder_num.sum(-1) - (attention_mask == 0).long().sum(-1)
+ )[:, None]
+ else:
+ # exclude padding on the right
+ val = seq_indices < (token_placeholder_num.sum(-1) - (attention_mask == 0).long().sum(-1))[:, None]
+
+ audio_to_overwrite &= val
+
+ if audio_to_overwrite.sum() != num_audio_tokens.sum():
+ raise ValueError(
+                f"The inputs provided to the model are wrong. The number of audio tokens is {num_special_audio_tokens} while"
+                f" the number of audios given to the model is {num_audios}. This prevents correct indexing and breaks batch generation."
+ )
+
+ final_embedding[audio_to_overwrite] = (
+ masked_audio_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ )
+ final_attention_mask |= audio_to_overwrite
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+        # Truncate sequences to `model_max_length` as media embeddings are inserted
+ inputs, labels = self.__truncate_sequence(final_embedding, final_labels)
+ return self.__batchify_sequence(inputs, labels)
+
+ def __embed_media_tokens(
+ self,
+ media: Dict[str, List[torch.Tensor]],
+ media_config: Dict[str, Dict[str, Any]],
+ media_meta: Dict[str, Dict[str, Any]]= None,
+ ) -> Dict[str, List[torch.Tensor]]:
+ embeds = defaultdict(deque)
+ for name in media:
+ if self.training:
+ # Gather metainfo of media objects from all ranks
+ info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
+ infos = list(chain(*distributed.all_gather(info)))
+
+ # The entire batch does not contain any media objects of this type.
+ if not infos:
+ continue
+
+ # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
+ if media.get(name) is None or len(media[name]) == 0:
+ dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
+ embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
+ continue
+            embeds[name] = deque(self.encoders[name](media[name], media_config[name], media_meta['sound_feature_masks']))  # hard-coded to pass sound feature masks
+ return embeds
+
+ def __truncate_sequence(
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
+ warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
+ inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
+ labels = [label[: self.tokenizer.model_max_length] for label in labels]
+ return inputs, labels
+
+ def __batchify_sequence(
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ batch_size = len(inputs)
+ device = inputs[0].device
+ hidden_size = inputs[0].shape[1]
+ max_length = max(inputs[k].shape[0] for k in range(batch_size))
+ attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
+
+ inputs_p, labels_p = [], []
+ for k in range(batch_size):
+ size_pk = max_length - inputs[k].shape[0]
+ inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
+ labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
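+            # Example (hedged numbers): with max_length = 5 and a sample of length 3, right padding appends
+            # the 2 zero-embedding / IGNORE_INDEX rows after the sample and masks the trailing positions,
+            # while left padding prepends them and masks the leading positions.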
+ if self.tokenizer.padding_side == "right":
+ attention_mask[k, inputs[k].shape[0] :] = False
+ inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
+ labels_pk = torch.cat([labels[k], labels_pk], dim=0)
+ else:
+ attention_mask[k, : -inputs[k].shape[0]] = False
+ inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
+ labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
+ inputs_p.append(inputs_pk)
+ labels_p.append(labels_pk)
+
+ inputs = torch.stack(inputs_p, dim=0)
+ labels = torch.stack(labels_p, dim=0)
+ return inputs, labels, attention_mask
+
+ def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
+ # Handle sequence parallelism
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+
+ # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
+ if PROCESS_GROUP_MANAGER is not None:
+ sp_degree = PROCESS_GROUP_MANAGER.sp_degree
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
+ sp_group = PROCESS_GROUP_MANAGER.sp_pg
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
+ ring_rank = PROCESS_GROUP_MANAGER.ring_rank
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
+ ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
+ ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
+
+ bs, shard_seqlen = position_ids.shape
+ sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
+ dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
+ sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
+
+ if sp_rank == 0:
+ original_start_id = 0
+ else:
+ original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
+ original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
+
+ # Gather attention_mask, position_ids, labels and input_embeds
+ all_inputs_embeds = torch.zeros(
+ bs,
+ torch.sum(sp_seq_len_cat),
+ inputs_embeds.shape[-1],
+ dtype=inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ ).contiguous()
+ all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
+ dist.barrier(group=sp_group)
+ dist.all_reduce(all_inputs_embeds, group=sp_group)
+ dist.barrier(group=sp_group)
+
+ attention_mask_list = [
+ torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
+ for i in range(sp_degree)
+ ]
+ position_ids_list = [
+ torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
+ for i in range(sp_degree)
+ ]
+ labels_list = [
+ torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
+ ]
+
+ dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
+ dist.all_gather(position_ids_list, position_ids, group=sp_group)
+ dist.all_gather(labels_list, labels, group=sp_group)
+
+ effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
+ effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
+ effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
+
+ global_attention_mask_list = []
+ global_position_ids_list = []
+ global_labels_list = []
+ global_inputs_embeds_list = []
+ for i in range(bs):
+ global_attention_mask_batch_list = []
+ global_position_ids_batch_list = []
+ global_labels_batch_list = []
+ global_inputs_embeds_batch_list = []
+ for j in range(sp_degree):
+ eff_len = effective_seqlen_batch_list[i][j]
+ prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
+
+ global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
+ global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
+ global_labels_batch_list.append(labels_list[j][i, :eff_len])
+ global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
+ global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
+ global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
+ global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
+ global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
+
+ global_attention_mask = torch.nn.utils.rnn.pad_sequence(
+ global_attention_mask_list, batch_first=True, padding_value=False
+ )
+ global_position_ids = torch.nn.utils.rnn.pad_sequence(
+ global_position_ids_list, batch_first=True, padding_value=-1
+ )
+ global_labels = torch.nn.utils.rnn.pad_sequence(
+ global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
+ )
+ global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
+ global_inputs_embeds_list, batch_first=True, padding_value=0
+ )
+
+ # Re-shard the inputs
+ if ring_degree > 1:
+ total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
+ new_seqlen_per_rank = total_effective_seqlen // sp_degree
+ assert torch.all(
+ total_effective_seqlen % sp_degree == 0
+ ), "total_effective_seqlen must be divisible by sp_degree"
+
+ max_new_seqlen = torch.max(new_seqlen_per_rank).item()
+
+ new_attention_mask = torch.zeros(
+ (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
+ )
+ new_position_ids = torch.zeros(
+ (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
+ )
+ new_labels = torch.full(
+ (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
+ )
+ new_inputs_embeds = torch.zeros(
+ (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
+ dtype=global_inputs_embeds.dtype,
+ device=global_inputs_embeds.device,
+ )
+
+ if ring_type == "ring_varlen":
+ for i in range(bs):
+ start_idx = new_seqlen_per_rank[i] * sp_rank
+ end_idx = start_idx + new_seqlen_per_rank[i]
+ new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
+ new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
+ new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
+ new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
+ i, start_idx:end_idx, :
+ ]
+ elif ring_type == "zigzag_ring_varlen":
+ chunk_size = total_effective_seqlen // (2 * sp_degree)
+ for i in range(bs):
+ # Zigzag pattern indices
+ if sp_degree == ring_degree:
+ forward_rank_idx = sp_rank
+ backward_rank_idx = 2 * sp_degree - sp_rank - 1
+ else:
+ ulysses_offset = ulysses_rank * ring_degree * 2
+ forward_rank_idx = ring_rank + ulysses_offset
+ backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
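+                        # Zigzag example (hedged): with sp_degree == ring_degree == 2 there are 4 chunks per
+                        # sample; rank 0 takes chunks 0 and 3 and rank 1 takes chunks 1 and 2, which keeps the
+                        # causal-attention workload roughly balanced across ranks.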
+
+ # Calculate start and end indices for the forward and backward zigzag
+ start_idx_fwd = forward_rank_idx * chunk_size[i]
+ end_idx_fwd = start_idx_fwd + chunk_size[i]
+
+ start_idx_bwd = backward_rank_idx * chunk_size[i]
+ end_idx_bwd = start_idx_bwd + chunk_size[i]
+
+ # Fill new tensors with zigzag data
+ new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
+ new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
+ i, start_idx_bwd:end_idx_bwd
+ ]
+
+ new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
+ new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
+ i, start_idx_bwd:end_idx_bwd
+ ]
+
+ new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
+ new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
+
+ new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
+ new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
+ i, start_idx_bwd:end_idx_bwd, :
+ ]
+ else:
+ raise ValueError(f"Invalid ring_type: {ring_type}")
+ else:
+ global_seq_len = global_attention_mask.shape[-1]
+ seq_len_sharded = global_seq_len // sp_degree
+ start_idx_reshard = seq_len_sharded * sp_rank
+ end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
+
+ new_attention_mask = torch.narrow(
+ global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
+ )
+ new_position_ids = torch.narrow(
+ global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
+ )
+ new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
+ new_inputs_embeds = torch.narrow(
+ global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
+ )
+
+ return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
+
+ device = inputs_embeds.device
+ batch_size = inputs_embeds.shape[0]
+ seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
+
+ # Pack all sequences together
+ inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
+ attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
+ position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
+ labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
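+        # Packed layout example (hedged, illustrative numbers): per-sample lengths [3, 2] produce a single
+        # row of length 3 + 2 (+1 for the dummy token appended below) whose position_ids restart for each
+        # original sample: [0, 1, 2, 0, 1, 0].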
+
+ # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
+ inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
+ attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
+ position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
+        labels_p.append(torch.tensor([IGNORE_INDEX], dtype=labels.dtype, device=device))
+
+ # Mask the first token of each sequence to avoid contamination
+ for label in labels_p:
+ label[0] = IGNORE_INDEX
+
+ # Batch the data
+ inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
+ attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
+ position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
+ labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
+
+ if hasattr(
+ self, "pad_to_multiple_of"
+ ): # related to quantization, please refer to ModelArguments for more information.
+ assert len(labels_p.shape) == 2
+ batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
+ hidden_size = inputs_embeds_p.shape[-1]
+
+ if max_length % self.pad_to_multiple_of != 0:
+ max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
+ difference = max_length - cur_length
+
+ inputs_embeds_p = torch.cat(
+ (
+ inputs_embeds_p,
+ torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
+ ),
+ dim=1,
+ )
+ labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
+ attention_mask_p = torch.cat(
+ (
+ attention_mask_p,
+ torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
+ ),
+ dim=1,
+ )
+ position_ids_p = torch.cat(
+ (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
+ )
+
+ return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
+
+ def get_xgr_logits_processor(self, response_format: ResponseFormat) -> List[LogitsProcessor]:
+ # Convert response format to logits processor
+ import xgrammar as xgr
+
+        logging.info("[XGrammar] Compiling grammar for constrained output")
+
+ if self.grammar_compiler is None:
+ # logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
+ self.grammar_compiler = xgr.GrammarCompiler(
+ xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
+ )
+
+ if response_format.type == "json_schema":
+ compiled_grammar = self.grammar_compiler.compile_json_schema(
+ response_format.json_schema.schema_,
+ indent=2,
+ )
+ else:
+ compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()
+
+ return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]
+
+ @torch.inference_mode()
+ def generate(
+ self,
+ input_ids: Optional[torch.FloatTensor] = None,
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
+ media_config: Dict[str, Dict[str, Any]] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ media_meta: Dict[str, Dict[str, Any]]= None,
+ **generation_kwargs,
+ ):
+ inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask, media_meta)
+ return self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=1024, **generation_kwargs)
+
+ @torch.inference_mode()
+ def generate_content(
+ self,
+ prompt: Union[str, List],
+ generation_config: Optional[GenerationConfig] = None,
+ response_format: Optional[ResponseFormat] = None,
+ ) -> str:
+ # TODO(zhijianl): Support directly taking conversation as input
+ conversation = [{"from": "human", "value": prompt}]
+
+ # Convert response format to logits processor
+ if response_format:
+ xgr_logits_processor = self.get_xgr_logits_processor(response_format)
+ else:
+ xgr_logits_processor = None
+
+ # Extract media from the conversation
+
+ # TODO (extract and preprocess should be done together, as the preprocess of image and video can be different, i.e. when dynamic res is used)
+ media, media_meta = extract_media(conversation, self.config)
+
+ # Process media
+ media_config = defaultdict(dict)
+ for name in media:
+ if name == "sound":
+ sounds = process_sounds(media["sound"]).half()
+ media[name] = [sound for sound in sounds]
+ sound_feature_masks = process_sound_masks(media_meta["sound_feature_masks"]).half()
+ media_meta["sound_feature_masks"] = [sound_mask for sound_mask in sound_feature_masks]
+ sound_embed_masks = process_sound_masks(media_meta["sound_embed_masks"]).half()
+ media_meta["sound_embed_masks"] = [sound_mask for sound_mask in sound_embed_masks]
+ else:
+ raise ValueError(f"Unsupported media type: {name}")
+
+ # Tokenize the conversation
+ input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
+
+ # Set up the generation config
+ generation_config = generation_config or self.default_generation_config
+
+ # Generate the response
+ try:
+ output_ids = self.generate(
+ input_ids=input_ids,
+ media=media,
+ media_config=media_config,
+ media_meta=media_meta,
+ generation_config=generation_config,
+ logits_processor=xgr_logits_processor, # structured generation
+ )
+ except ValueError:
+ if not generation_config.do_sample:
+ raise
+ # FIXME(zhijianl): This is a temporary workaround for the sampling issue
+ logging.warning("Generation failed with sampling, retrying with greedy decoding.")
+ generation_config.do_sample = False
+ output_ids = self.generate(
+ input_ids=input_ids,
+ media=media,
+ media_config=media_config,
+ media_meta=media_meta,
+ generation_config=generation_config,
+ logits_processor=xgr_logits_processor,
+ )
+
+ # Decode the response
+ response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
+ return response
+
+ @property
+ def default_generation_config(self) -> GenerationConfig:
+ generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
+ if self.tokenizer.eos_token_id is None:
+ raise ValueError("Tokenizer must have an EOS token")
+ if generation_config.max_length == GenerationConfig().max_length:
+ generation_config.max_length = self.tokenizer.model_max_length
+ if generation_config.pad_token_id is None:
+ generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
+ if generation_config.bos_token_id is None:
+ generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
+ if generation_config.eos_token_id is None:
+ generation_config.eos_token_id = self.tokenizer.stop_token_ids
+ return generation_config
diff --git a/llava/model/loss.py b/llava/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7b8a2d3be72069fe6370fed80ccf1086ff3877
--- /dev/null
+++ b/llava/model/loss.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from typing import List, Union
+
+import torch
+from torch.nn.functional import cross_entropy
+
+from llava.constants import IGNORE_INDEX
+
+__all__ = ["soft_cross_entropy"]
+
+
+def soft_cross_entropy(
+ outputs: torch.Tensor,
+ targets: torch.Tensor,
+ soft_tokens: Union[torch.Tensor, List[int]],
+ std: float = 1,
+ ignore_index: int = IGNORE_INDEX,
+) -> torch.Tensor:
+ # Remove last token from outputs and first token from targets
+ outputs = outputs[..., :-1, :].contiguous()
+ targets = targets[..., 1:].contiguous()
+
+ # Flatten outputs and targets
+ targets = targets.view(-1)
+ outputs = outputs.view(targets.size(0), -1)
+
+ # Remove outputs and targets with ignore_index
+ indices = targets != ignore_index
+ outputs = outputs[indices]
+ targets = targets[indices]
+
+ # Convert soft token IDs to tensor
+ if isinstance(soft_tokens, list):
+ soft_tokens = torch.tensor(soft_tokens).to(targets)
+
+ # Calculate loss for non-soft tokens
+ indices = torch.isin(targets, soft_tokens, invert=True)
+ loss = cross_entropy(outputs[indices], targets[indices], reduction="sum")
+
+ # Calculate loss for soft tokens
+ indices = torch.isin(targets, soft_tokens)
+ targets_indices = torch.zeros_like(outputs[indices])
+ for k, target in enumerate(targets[indices]):
+ dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
+ targets_indices[k][soft_tokens] = dist / dist.sum()
+ loss += cross_entropy(outputs[indices], targets_indices, reduction="sum")
+
+ # Return average loss
+ return loss / targets.size(0)
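+
+# A minimal usage sketch (hedged; the shapes, vocab size and soft-token IDs below are illustrative only):
+#
+#   logits = torch.randn(2, 8, 32000)                 # [batch, seq_len, vocab]
+#   labels = torch.randint(0, 32000, (2, 8))
+#   loss = soft_cross_entropy(logits, labels, soft_tokens=[100, 101, 102], std=1.0)
+#
+# Soft tokens (e.g. discretized numeric bins) are supervised with a Gaussian-smoothed target over the
+# neighboring soft-token bins instead of a one-hot target, while all other tokens use standard cross entropy.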
diff --git a/llava/model/make_delta.py b/llava/model/make_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..420cb41e9ee3c6e6ab9f505f15bf48fbe9ecaf5c
--- /dev/null
+++ b/llava/model/make_delta.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+"""
+Usage:
+python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llava.model.utils import auto_upgrade
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading target model")
+ auto_upgrade(target_model_path)
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Calculating delta")
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+ if name not in base.state_dict():
+ assert name in [
+ "model.mm_projector.weight",
+ "model.mm_projector.bias",
+ ], f"{name} not in base model"
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data -= base.state_dict()[name]
+ else:
+ assert name in [
+ "model.embed_tokens.weight",
+ "lm_head.weight",
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+ bparam = base.state_dict()[name]
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
+
+ print("Saving delta")
+ if hub_repo_id:
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+ else:
+ kwargs = {}
+ target.save_pretrained(delta_path, **kwargs)
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+ parser.add_argument("--hub-repo-id", type=str, default=None)
+ args = parser.parse_args()
+
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
diff --git a/llava/model/multimodal_encoder/afwhisper_audio_encoder.py b/llava/model/multimodal_encoder/afwhisper_audio_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..146259962fb1502d99ac591bf6ef2139514fa12d
--- /dev/null
+++ b/llava/model/multimodal_encoder/afwhisper_audio_encoder.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+import torch
+from transformers import PretrainedConfig
+from llava.model.multimodal_encoder.modeling_whisper import AFWhisperEncoder
+from llava.model.multimodal_encoder.sound_encoder import SoundTower
+
+class AFWhisperSoundTower(SoundTower):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
+ super().__init__(model_name_or_path, config)
+ self.sound_tower = AFWhisperEncoder.from_pretrained(model_name_or_path)
+ self.is_loaded = True
diff --git a/llava/model/multimodal_encoder/builder.py b/llava/model/multimodal_encoder/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca0622095518c9c61f68fb46dd4e8409d06e4594
--- /dev/null
+++ b/llava/model/multimodal_encoder/builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+import os
+
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
+from .whisper_encoder import WhisperSpeechTower
+from .afwhisper_audio_encoder import AFWhisperSoundTower
+
+def build_speech_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
+ if model_name_or_path is None:
+ return None
+ speech_tower = WhisperSpeechTower(model_name_or_path, config)
+ config.speech_hidden_size = speech_tower.config.hidden_size
+ return speech_tower
+
+def build_sound_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
+ if model_name_or_path is None:
+ return None
+ sound_tower = AFWhisperSoundTower(model_name_or_path, config)
+    config.sound_hidden_size = 1280  # hard-coded hidden size of the sound tower output
+ return sound_tower
+
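+# A minimal usage sketch (hedged; the checkpoint paths are illustrative):
+#
+#   config = AutoConfig.from_pretrained("path/to/llava_config")
+#   sound_tower = build_sound_tower("path/to/af_whisper_encoder", config)
+#   # As a side effect, `config.sound_hidden_size` is populated for the sound_mm_projector builder.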
diff --git a/llava/model/multimodal_encoder/clip_encoder.py b/llava/model/multimodal_encoder/clip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d85d2c8939fc9426b1b3cad06876efb57642f60
--- /dev/null
+++ b/llava/model/multimodal_encoder/clip_encoder.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModel, PretrainedConfig
+
+from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2
+
+
+class CLIPVisionTower(VisionTower):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
+ super().__init__(model_name_or_path, config)
+ self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
+ self.vision_tower = CLIPVisionModel.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype))
+ self.is_loaded = True
+
+
+class CLIPVisionTowerS2(VisionTowerS2):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
+ super().__init__(model_name_or_path, config)
+ self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
+ self.vision_tower = CLIPVisionModel.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype))
+
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
+ self.image_processor.size["shortest_edge"] = self.scales[-1]
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.scales[-1]
+
+ self.is_loaded = True
diff --git a/llava/model/multimodal_encoder/image_processor.py b/llava/model/multimodal_encoder/image_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bccfcef8f7b83ef3b584a6a20ecadb6ab54c2fe1
--- /dev/null
+++ b/llava/model/multimodal_encoder/image_processor.py
@@ -0,0 +1,552 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Image processor class for RADIO."""
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+from PIL.Image import Image
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from transformers.image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
+from transformers.image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+)
+from transformers.utils import (
+ TensorType,
+ is_tf_available,
+ is_torch_available,
+ is_torchvision_available,
+ logging,
+ requires_backends,
+)
+
+if is_torch_available():
+ import torch
+ import torch.nn.functional as F
+
+if is_torchvision_available():
+ from torchvision.ops.boxes import batched_nms
+
+# if is_tf_available():
+# import tensorflow as tf
+# from tensorflow.experimental import numpy as tnp
+
+# from ...tf_utils import flatten, shape_list
+
+logger = logging.get_logger(__name__)
+
+
+def rank_print(s):
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ print(f"[Rank {rank}] {s}")
+
+
+class ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs an image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
+ Size of the output image after resizing. If "longest_edge" is specified, resizes the longest edge of the image to match
+ `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
+ to that size, possibly changing the aspect ratio. Can be overridden by the `size` parameter in the
+ `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
+ `preprocess` method.
+ pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+ Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
+ method.
+ pad_value (`float` or `Iterable[float]`, *optional*, defaults to `0.`):
+ Value of padded pixels.
+ pad_multiple (`int`, *optional*, defaults to `None`):
+ Pad to a multiple of specified number.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: bool = True,
+ pad_size: int = None,
+ pad_multiple: int = None,
+ pad_value: Optional[Union[float, List[float]]] = 0.0,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"longest_edge": 1024}
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
+
+ if pad_size is not None and pad_multiple is not None:
+ raise ValueError("pad_size and pad_multiple should not be set at the same time.")
+
+ pad_size = (
+ pad_size if pad_size is not None else {"height": 1024, "width": 1024} if pad_multiple is not None else None
+ )
+ if do_pad:
+ pad_size = get_size_dict(pad_size, default_to_square=True)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+ self.do_pad = do_pad
+ self.pad_multiple = pad_multiple
+ self.pad_size = pad_size
+ self.pad_value = tuple(pad_value) if isinstance(pad_value, list) else pad_value
+ self.do_convert_rgb = do_convert_rgb
+ self._valid_processor_keys = [
+ "images",
+ "segmentation_maps",
+ "do_resize",
+ "size",
+ "resample",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "do_pad",
+ "pad_size",
+ "do_convert_rgb",
+ "return_tensors",
+ "data_format",
+ "input_data_format",
+ ]
+
+ def pad_image(
+ self,
+ image: np.ndarray,
+ pad_size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Pad an image to `(pad_size["height"], pad_size["width"])` to the right and bottom.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ pad_size (`Dict[str, int]`):
+ Size of the output image after padding.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
+ `data_format` of the `image` will be used.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ output_height, output_width = pad_size["height"], pad_size["width"]
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+
+ pad_width = output_width - input_width
+ pad_height = output_height - input_height
+
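+        # Example (hedged numbers): a 600x800 (height x width) input with pad_size {"height": 1024, "width": 1024}
+        # gets 424 rows appended at the bottom and 224 columns appended on the right.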
+ padded_image = pad(
+ image,
+ ((0, pad_height), (0, pad_width)),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ constant_values=self.pad_value,
+ **kwargs,
+ )
+ return padded_image
+
+ def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
+ """
+ Compute the output size given input size and target long side length.
+ """
+ oldh, oldw = old_shape
+ scale = longest_edge * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ newh = int(newh + 0.5)
+ neww = int(neww + 0.5)
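+        # Example (hedged numbers): old_shape = (480, 640) with longest_edge = 1024 gives scale = 1.6
+        # and an output size of (768, 1024).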
+ return (newh, neww)
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"longest_edge": int}` or `{"width": int, "height": int}` specifying the size
+ of the output image. If "longest_edge" is specified, resizes the longest edge of the image to match
+ `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
+ to that size, possibly changing the aspect ratio.
+ resample:
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "longest_edge" not in size:
+ if "width" not in size or "height" not in size:
+ raise ValueError(
+ f"The `size` dictionary must contain the key `longest_edge`, or `width` and `height`. Got {size.keys()}"
+ )
+ input_size = get_image_size(image, channel_dim=input_data_format)
+ if "longest_edge" in size:
+ output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
+ else:
+ output_height, output_width = size["height"], size["width"]
+ return resize(
+ image,
+ size=(output_height, output_width),
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def _preprocess(
+ self,
+ image: ImageInput,
+ do_resize: bool,
+ do_rescale: bool,
+ do_normalize: bool,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = None,
+ rescale_factor: Optional[float] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: Optional[bool] = None,
+ pad_size: Optional[Dict[str, int]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+ if do_pad:
+ if self.pad_multiple:
+ h, w = get_image_size(image, channel_dim=input_data_format)
+ pad_size = {
+ "height": math.ceil(h / self.pad_multiple) * self.pad_multiple,
+ "width": math.ceil(w / self.pad_multiple) * self.pad_multiple,
+ }
+
+ image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
+
+ return image, reshaped_input_size
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: Optional[bool] = None,
+ pad_size: Optional[Dict[str, int]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
+ if isinstance(image, Image):
+ # PIL always uses Channels Last.
+ input_data_format = ChannelDimension.LAST
+
+ # PIL RGBA images are converted to RGB
+ # mode_before = image.mode
+ if do_convert_rgb:
+ image = convert_to_rgb(image)
+
+ # All transformations expect numpy arrays.
+ image_ = image
+ image = to_numpy_array(image)
+
+ # if isinstance(image_, np.ndarray):
+ # rank_print(f"preprocess image type={type(image_)} shape={image_.shape} array shape={image.shape}")
+ # elif isinstance(image_, Image):
+ # rank_print(f"preprocessimage type={type(image_)} size={image_.size} mode={image_.mode} array shape={image.shape}")
+ # else:
+ # rank_print(f"preprocess unknown image type={type(image_)} array shape={image.shape}")
+
+ if len(image.shape) == 2:
+ h, w = image.shape
+ ret = np.empty((h, w, 3), dtype=np.uint8)
+ ret[:, :, 0] = image
+ ret[:, :, 1] = image
+ ret[:, :, 2] = image
+ image = ret
+ rank_print(f"preprocess new image shape={image.shape}")
+        elif len(image.shape) == 3 and image.shape[-1] == 1:
+            # single-channel image with an explicit channel dimension; expand it to 3 channels
+            h, w = image.shape[:2]
+            ret = np.empty((h, w, 3), dtype=np.uint8)
+ ret[:, :, 0] = image[:, :, 0]
+ ret[:, :, 1] = image[:, :, 0]
+ ret[:, :, 2] = image[:, :, 0]
+ image = ret
+ rank_print(f"preprocess new image shape={image.shape}")
+
+ if is_scaled_image(image) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ original_size = get_image_size(image, channel_dim=input_data_format)
+
+ image, reshaped_input_size = self._preprocess(
+ image=image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_pad=do_pad,
+ pad_size=pad_size,
+ input_data_format=input_data_format,
+ )
+
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+ # rank_print(f"preprocess original_size={original_size} reshaped_input_size={reshaped_input_size} image shape={image.shape} type={type(image)}")
+
+ # if image is a single channel convert to rgb
+ if do_convert_rgb and image.shape[0] == 1:
+ c, h, w = image.shape
+            ret = np.empty((3, h, w), dtype=image.dtype)
+ ret[0, :, :] = image[0, :, :]
+ ret[1, :, :] = image[0, :, :]
+ ret[2, :, :] = image[0, :, :]
+ image = ret
+ rank_print(f"preprocess final: {image.shape}")
+
+ return image, original_size, reshaped_input_size
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional["PILImageResampling"] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[Union[int, float]] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: Optional[bool] = None,
+ pad_size: Optional[Dict[str, int]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Controls the size of the image after `resize`. The longest edge of the image is resized to
+ `size["longest_edge"]` whilst preserving the aspect ratio.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image pixel values by rescaling factor.
+ rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to apply to the image pixel values.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image.
+ pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
+ Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
+ `pad_size["width"]` if `do_pad` is set to `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ pad_size = pad_size if pad_size is not None else self.pad_size
+ if do_pad:
+ pad_size = get_size_dict(pad_size, default_to_square=True)
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ images, original_sizes, reshaped_input_sizes = zip(
+ *(
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_pad=do_pad,
+ pad_size=pad_size,
+ do_convert_rgb=do_convert_rgb,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for img in images
+ )
+ )
+
+ data = {
+ "pixel_values": images,
+ "original_sizes": original_sizes,
+ "reshaped_input_sizes": reshaped_input_sizes,
+ }
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
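+
+# --- Editor's illustrative sketch; not part of the original patch. ---
+# Rough shape of a `preprocess` call: `processor` is a hypothetical instance of this image
+# processor class and the array is a synthetic HWC uint8 image.
+# >>> import numpy as np
+# >>> image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
+# >>> batch = processor.preprocess(image)
+# >>> sorted(batch.keys())
+# ['original_sizes', 'pixel_values', 'reshaped_input_sizes']
+# `pixel_values` holds the resized/normalized/padded CHW arrays, `original_sizes` the input
+# (height, width) pairs, and `reshaped_input_sizes` the sizes after the longest-edge resize.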
diff --git a/llava/model/multimodal_encoder/imagebind_encoder.py b/llava/model/multimodal_encoder/imagebind_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f66811c9cfd57f246d941a9b9843366793320a67
--- /dev/null
+++ b/llava/model/multimodal_encoder/imagebind_encoder.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+import torch
+from transformers import PretrainedConfig
+from llava.model.ImageBind.models import imagebind_model
+from llava.model.multimodal_encoder.sound_encoder import SoundTower
+
+class ImagebindSoundTower(SoundTower):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
+ super().__init__(model_name_or_path, config)
+ self.sound_tower, _ = imagebind_model.imagebind_huge()
+ self.sound_tower.load_state_dict(torch.load(model_name_or_path))
+ self.is_loaded = True
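+
+# Editor's note (illustrative, not part of the original file): `model_name_or_path` is expected to be
+# a path to an ImageBind-Huge checkpoint whose state dict matches `imagebind_model.imagebind_huge()`.
+# Hypothetical construction, assuming `config` is the model's PretrainedConfig:
+# >>> tower = ImagebindSoundTower("/path/to/imagebind_huge.pth", config)
+# >>> tower.is_loaded
+# True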
diff --git a/llava/model/multimodal_encoder/intern/configuration_intern_vit.py b/llava/model/multimodal_encoder/intern/configuration_intern_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..10520a2a4e38f387e8df9ca80419206ddb8b2220
--- /dev/null
+++ b/llava/model/multimodal_encoder/intern/configuration_intern_vit.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2023 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class InternVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of color channels in the input images (e.g., 3 for RGB).
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ qkv_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the queries, keys and values in the self-attention layers.
+ hidden_size (`int`, *optional*, defaults to 3200):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_attention_heads (`int`, *optional*, defaults to 25):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 12800):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ qk_normalization (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the queries and keys in the self-attention layers.
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of hidden layers in the Transformer encoder.
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
+ Whether to use flash attention mechanism.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Dropout rate for stochastic depth.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 0.1):
+ A factor for layer scale.
+ """
+
+ model_type = "intern_vit_6b"
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=14,
+ image_size=224,
+ qkv_bias=False,
+ hidden_size=3200,
+ num_attention_heads=25,
+ intermediate_size=12800,
+ qk_normalization=True,
+ num_hidden_layers=48,
+ use_flash_attn=True,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ dropout=0.0,
+ drop_path_rate=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=0.1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.drop_path_rate = drop_path_rate
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.qkv_bias = qkv_bias
+ self.qk_normalization = qk_normalization
+ self.use_flash_attn = use_flash_attn
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if "vision_config" in config_dict:
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
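+
+# Editor's illustrative sketch (not part of the original file): instantiating the config with its
+# defaults reproduces the InternViT-6B geometry documented in the docstring above.
+# >>> configuration = InternVisionConfig()
+# >>> configuration.hidden_size, configuration.num_hidden_layers, configuration.num_attention_heads
+# (3200, 48, 25)
+# >>> (configuration.image_size // configuration.patch_size) ** 2  # patch tokens per 224x224 image
+# 256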
diff --git a/llava/model/multimodal_encoder/intern/flash_attention.py b/llava/model/multimodal_encoder/intern/flash_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..d28326a56832fcdb816a9df13de8443a84d5e092
--- /dev/null
+++ b/llava/model/multimodal_encoder/intern/flash_attention.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+try: # v1
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError: # v2
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+
+from flash_attn.bert_padding import pad_input, unpad_input
+
+
+class FlashAttention(nn.Module):
+ """Implement the scaled dot product attention with softmax.
+ Arguments
+ ---------
+ softmax_scale: The temperature to use for the softmax attention.
+ (default: 1/sqrt(d_keys) where d_keys is computed at
+ runtime)
+ attention_dropout: The dropout rate to apply to the attention
+ (default: 0.0)
+ """
+
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+ super().__init__()
+ self.softmax_scale = softmax_scale
+ self.dropout_p = attention_dropout
+
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, max_s=None, need_weights=False):
+ """Implements the multihead softmax attention.
+ Arguments
+ ---------
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+ if unpadded: (nnz, 3, h, d)
+ key_padding_mask: a bool tensor of shape (B, S)
+ """
+ assert not need_weights
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
+ assert qkv.is_cuda
+
+ if cu_seqlens is None:
+ batch_size = qkv.shape[0]
+ seqlen = qkv.shape[1]
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
+ max_s = seqlen
+ cu_seqlens = torch.arange(
+ 0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
+ )
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv,
+ cu_seqlens,
+ max_s,
+ self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale,
+ causal=causal,
+ )
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ x_unpad,
+ cu_seqlens,
+ max_s,
+ self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale,
+ causal=causal,
+ )
+ output = rearrange(
+ pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen),
+ "b s (h d) -> b s h d",
+ h=nheads,
+ )
+ else:
+ assert max_s is not None
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv,
+ cu_seqlens,
+ max_s,
+ self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale,
+ causal=causal,
+ )
+
+ return output, None
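+
+# Editor's illustrative sketch (not part of the original file): FlashAttention expects packed QKV in
+# (batch, seq, 3, heads, head_dim) layout, in fp16/bf16 on a CUDA device, and needs the flash-attn
+# package installed. Shapes below are hypothetical:
+# >>> attn = FlashAttention(attention_dropout=0.0)
+# >>> qkv = torch.randn(2, 257, 3, 25, 128, dtype=torch.float16, device="cuda")
+# >>> out, _ = attn(qkv)
+# >>> out.shape
+# torch.Size([2, 257, 25, 128])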
diff --git a/llava/model/multimodal_encoder/intern/modeling_intern_vit.py b/llava/model/multimodal_encoder/intern/modeling_intern_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ea886e581edfbdb28da7989b4f3451ed6a9e6f
--- /dev/null
+++ b/llava/model/multimodal_encoder/intern/modeling_intern_vit.py
@@ -0,0 +1,549 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2023 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig
+
+try: # flash-attn is optional; fall back to the naive attention path when it is unavailable
+ from .flash_attention import FlashAttention
+ has_flash_attn = True
+except ImportError:
+ FlashAttention, has_flash_attn = None, False
+
+
+logger = logging.get_logger(__name__)
+
+
+""" DropBlock, DropPath
+
+PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
+
+Papers:
+DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+
+Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+
+Code:
+DropBlock impl inspired by two Tensorflow impl that I liked:
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+
+def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
+ """generate N-D grid in dimension order.
+
+ The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
+
+ That is, the statement
+ [X1,X2,X3] = ndgrid(x1,x2,x3)
+
+ produces the same result as
+
+ [X2,X1,X3] = meshgrid(x2,x1,x3)
+
+ This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
+ torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
+
+ """
+ try:
+ return torch.meshgrid(*tensors, indexing="ij")
+ except TypeError:
+ # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
+ # the old behaviour of meshgrid was 'ij'
+ return torch.meshgrid(*tensors)
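+
+# Editor's illustrative sketch (not part of the original file): ndgrid keeps 'ij' (matrix) indexing,
+# i.e. the first output varies along dim 0 and the second along dim 1.
+# >>> h, w = ndgrid(torch.arange(2), torch.arange(3))
+# >>> h.shape, h[:, 0], w[0, :]
+# (torch.Size([2, 3]), tensor([0, 1]), tensor([0, 1, 2]))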
+
+
+def drop_block_2d(
+ x,
+ drop_prob: float = 0.1,
+ block_size: int = 7,
+ gamma_scale: float = 1.0,
+ with_noise: bool = False,
+ inplace: bool = False,
+ batchwise: bool = False,
+):
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
+ """
+ B, C, H, W = x.shape
+ total_size = W * H
+ clipped_block_size = min(block_size, min(W, H))
+ # seed_drop_rate, the gamma parameter
+ gamma = (
+ gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
+ )
+
+ # Forces the block to be inside the feature map.
+ w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
+ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & (
+ (h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)
+ )
+ valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
+
+ if batchwise:
+ # one mask for whole batch, quite a bit faster
+ uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
+ else:
+ uniform_noise = torch.rand_like(x)
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
+ block_mask = -F.max_pool2d(
+ -block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2 # block_size,
+ )
+
+ if with_noise:
+ normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+ if inplace:
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
+ else:
+ x = x * block_mask + normal_noise * (1 - block_mask)
+ else:
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
+ if inplace:
+ x.mul_(block_mask * normalize_scale)
+ else:
+ x = x * block_mask * normalize_scale
+ return x
+
+
+def drop_block_fast_2d(
+ x: torch.Tensor,
+ drop_prob: float = 0.1,
+ block_size: int = 7,
+ gamma_scale: float = 1.0,
+ with_noise: bool = False,
+ inplace: bool = False,
+):
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+ DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
+ block mask at edges.
+ """
+ B, C, H, W = x.shape
+ total_size = W * H
+ clipped_block_size = min(block_size, min(W, H))
+ gamma = (
+ gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
+ )
+
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
+ block_mask = F.max_pool2d(
+ block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
+ )
+
+ if with_noise:
+ normal_noise = torch.empty_like(x).normal_()
+ if inplace:
+ x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
+ else:
+ x = x * (1.0 - block_mask) + normal_noise * block_mask
+ else:
+ block_mask = 1 - block_mask
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
+ if inplace:
+ x.mul_(block_mask * normalize_scale)
+ else:
+ x = x * block_mask * normalize_scale
+ return x
+
+
+class DropBlock2d(nn.Module):
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
+
+ def __init__(
+ self,
+ drop_prob: float = 0.1,
+ block_size: int = 7,
+ gamma_scale: float = 1.0,
+ with_noise: bool = False,
+ inplace: bool = False,
+ batchwise: bool = False,
+ fast: bool = True,
+ ):
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.gamma_scale = gamma_scale
+ self.block_size = block_size
+ self.with_noise = with_noise
+ self.inplace = inplace
+ self.batchwise = batchwise
+ self.fast = fast # FIXME finish comparisons of fast vs not
+
+ def forward(self, x):
+ if not self.training or not self.drop_prob:
+ return x
+ if self.fast:
+ return drop_block_fast_2d(
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace
+ )
+ else:
+ return drop_block_2d(
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
+ )
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class InternRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+try:
+ from apex.normalization import FusedRMSNorm
+
+ InternRMSNorm = FusedRMSNorm # noqa
+
+ logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm")
+except ImportError:
+ # using the normal InternRMSNorm
+ pass
+except Exception:
+ logger.warning("discovered apex but it failed to load, falling back to InternRMSNorm")
+ pass
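+
+# Editor's note (illustrative, not part of the original file): the pure-PyTorch InternRMSNorm above
+# computes weight * x / sqrt(mean(x**2, dim=-1) + eps) in float32 and casts back to the input dtype;
+# when apex is importable, the fused kernel is substituted with equivalent RMSNorm semantics.
+# >>> norm = InternRMSNorm(8)  # pure-PyTorch path
+# >>> norm(torch.randn(2, 5, 8)).shape
+# torch.Size([2, 5, 8])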
+
+
+class InternVisionEmbeddings(nn.Module):
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.randn(1, 1, self.embed_dim),
+ )
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding.to(target_dtype)
+ return embeddings
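+
+# Editor's note (illustrative, not part of the original file): with the default 224px / patch-14
+# geometry the embeddings emit 16 * 16 = 256 patch tokens plus one class token, i.e. 257 positions,
+# matching `position_embedding`. Tiny hypothetical config for illustration:
+# >>> emb = InternVisionEmbeddings(InternVisionConfig(hidden_size=64))
+# >>> emb(torch.randn(1, 3, 224, 224)).shape
+# torch.Size([1, 257, 64])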
+
+
+class InternAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
+ if config.use_flash_attn and not has_flash_attn:
+ print("Warning: Flash Attention is not available, use_flash_attn is set to False.")
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.scale = self.head_dim**-0.5
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+ self.attn_drop = nn.Dropout(config.attention_dropout)
+ self.proj_drop = nn.Dropout(config.dropout)
+
+ self.qk_normalization = config.qk_normalization
+
+ if self.qk_normalization:
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ if self.use_flash_attn:
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False)
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+ return x
+
+
+class InternMLP(nn.Module):
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.act = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class InternVisionEncoderLayer(nn.Module):
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.attn = InternAttention(config)
+ self.mlp = InternMLP(config)
+ self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ """
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
+
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
+
+ return hidden_states
+
+
+class InternVisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is an
+ [`InternVisionEncoderLayer`].
+
+ Args:
+ config (`InternVisionConfig`):
+ The corresponding vision configuration for the `InternVisionEncoder`.
+ """
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
+ self.layers = nn.ModuleList(
+ [InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = True
+
+ def forward(
+ self,
+ inputs_embeds,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Embedded representation of the inputs. Should be float, not int tokens.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ hidden_states = inputs_embeds
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ )
+ hidden_states = layer_outputs
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
+
+
+class InternVisionModel(PreTrainedModel):
+ main_input_name = "pixel_values"
+ config_class = InternVisionConfig
+ _no_split_modules = ["InternVisionEncoderLayer"]
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = InternVisionEmbeddings(config)
+ self.encoder = InternVisionEncoder(config)
+
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
+ pos_emb = self.embeddings.position_embedding
+ _, num_positions, embed_dim = pos_emb.shape
+ cls_emb = pos_emb[:, :1, :]
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode="bicubic", align_corners=False)
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
+ logger.info(f"Resized position embeddings from {old_size} to {new_size}")
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_embeds: Optional[torch.FloatTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None and pixel_embeds is None:
+ raise ValueError("You have to specify pixel_values or pixel_embeds")
+
+ if pixel_embeds is not None:
+ hidden_states = pixel_embeds
+ else:
+ if len(pixel_values.shape) == 4:
+ hidden_states = self.embeddings(pixel_values)
+ else:
+ raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = encoder_outputs.last_hidden_state
+ pooled_output = last_hidden_state[:, 0, :]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
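+
+# Editor's illustrative sketch (not part of the original file): end-to-end shapes for a forward pass,
+# using a tiny hypothetical config so it can run on CPU without flash-attn.
+# >>> cfg = InternVisionConfig(hidden_size=64, num_attention_heads=4, intermediate_size=128,
+# ...                          num_hidden_layers=2, use_flash_attn=False)
+# >>> model = InternVisionModel(cfg).eval()
+# >>> with torch.no_grad():
+# ...     out = model(pixel_values=torch.randn(1, 3, 224, 224))
+# >>> out.last_hidden_state.shape, out.pooler_output.shape
+# (torch.Size([1, 257, 64]), torch.Size([1, 64]))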
diff --git a/llava/model/multimodal_encoder/intern_encoder.py b/llava/model/multimodal_encoder/intern_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a449e3f0c39b47b0acd07a33809a78454900193e
--- /dev/null
+++ b/llava/model/multimodal_encoder/intern_encoder.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torchvision.transforms as T
+from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoConfig, AutoModel
+from transformers.image_processing_utils import BaseImageProcessor
+
+from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig
+from llava.model.multimodal_encoder.intern.modeling_intern_vit import InternVisionModel
+from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2
+
+
+def build_transform(input_size):
+ transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ ]
+ )
+ return transform
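+
+# Editor's illustrative sketch (not part of the original file): build_transform maps any PIL image to
+# a normalized CHW float tensor at the requested square resolution (ImageNet mean/std).
+# >>> from PIL import Image
+# >>> img = Image.new("L", (640, 480))  # non-RGB inputs are converted to RGB first
+# >>> build_transform(448)(img).shape
+# torch.Size([3, 448, 448])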
+
+
+class InternVisionPreprocessor(BaseImageProcessor):
+ def __init__(self, resize_size=448):
+ super().__init__()
+ self.resize_size = resize_size
+
+ @property
+ def size(self):
+ return {"height": self.resize_size, "width": self.resize_size}
+
+ def preprocess(self, image, return_tensors):
+ transform = build_transform(self.resize_size)
+ if isinstance(image, list):
+ image_tensor = [transform(img) for img in image]
+ return {"pixel_values": image_tensor}
+ else:
+ image_tensor = transform(image)
+ return {"pixel_values": [image_tensor]}
+
+
+class InternVisionTower(VisionTower):
+ def __init__(self, vision_tower, config, drop_path_rate=0.0):
+ super().__init__(vision_tower, config)
+ self._drop_path_rate = drop_path_rate
+
+ self.image_processor = InternVisionPreprocessor()
+ vision_config = InternVisionConfig.from_pretrained(vision_tower)
+ vision_config.drop_path_rate = self._drop_path_rate
+ self.vision_tower = InternVisionModel.from_pretrained(
+ vision_tower, torch_dtype=eval(config.model_dtype), config=vision_config
+ )
+
+ self.is_loaded = True
+
+
+class InternVisionTowerS2(VisionTowerS2):
+ def __init__(self, vision_tower, config, drop_path_rate=0.0):
+ super().__init__(vision_tower, config)
+ self._drop_path_rate = drop_path_rate
+
+ self.image_processor = InternVisionPreprocessor(resize_size=self.scales[-1])
+ vision_config = InternVisionConfig.from_pretrained(vision_tower)
+ vision_config.drop_path_rate = self._drop_path_rate
+ self.vision_tower = InternVisionModel.from_pretrained(
+ vision_tower, torch_dtype=eval(config.model_dtype), config=vision_config
+ )
+
+ self.is_loaded = True
+
+
+AutoConfig.register("intern_vit_6b", InternVisionConfig)
+AutoModel.register(InternVisionConfig, InternVisionModel)
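+
+# Editor's note (illustrative, not part of the original file): the registrations above let the Auto*
+# factories resolve checkpoints whose config declares `"model_type": "intern_vit_6b"`; the path below
+# is hypothetical.
+# >>> cfg = AutoConfig.from_pretrained("/path/to/intern_vit_6b_checkpoint")
+# >>> type(cfg).__name__
+# 'InternVisionConfig'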
diff --git a/llava/model/multimodal_encoder/modeling_whisper.py b/llava/model/multimodal_encoder/modeling_whisper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f15f687fdf1c77fe323435c29051d624d967eb7d
--- /dev/null
+++ b/llava/model/multimodal_encoder/modeling_whisper.py
@@ -0,0 +1,1366 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2Audio model."""
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import BaseModelOutput, ModelOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.models.auto import AutoModel, AutoModelForCausalLM
+from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
+
+
+if is_flash_attn_2_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Qwen2AudioConfig"
+
+
+@dataclass
+class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for Qwen2Audio causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Attentions mask, used to update attention mask and position_ids.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ attention_mask: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.whisper.modeling_whisper.WhisperAttention with Whisper->Qwen2Audio
+class Qwen2AudioAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ layer_idx: Optional[int] = None,
+ config: Optional[Qwen2AudioConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ if layer_idx is None and is_decoder:
+ logger.warning_once(
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+ "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+ self.layer_idx = layer_idx
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[EncoderDecoderCache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ past_key_value.is_updated[self.layer_idx] = True
+ past_key_value = past_key_value.cross_attention_cache
+ else:
+ past_key_value = past_key_value.self_attention_cache
+
+ # use key_value_states if cross attention
+ current_states = key_value_states if key_value_states is not None else hidden_states
+ if is_cross_attention and past_key_value and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value.key_cache[self.layer_idx]
+ value_states = past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_probs, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2 with Whisper->Qwen2Audio
+class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
+ """
+ Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[EncoderDecoderCache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
+ "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+ )
+ # Qwen2AudioFlashAttention2 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ past_key_value.is_updated[self.layer_idx] = True
+ past_key_value = past_key_value.cross_attention_cache
+ else:
+ past_key_value = past_key_value.self_attention_cache
+
+ # use key_value_states if cross attention
+ current_states = key_value_states if key_value_states is not None else hidden_states
+ if is_cross_attention and past_key_value and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value.key_cache[self.layer_idx]
+ value_states = past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
+ # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, : key_states.shape[-2]]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ causal_mask,
+ tgt_len,
+ dropout=self.dropout if self.training else 0.0,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Copied from transformers.models.whisper.modeling_whisper.WhisperSdpaAttention with Whisper->Qwen2Audio
+class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[EncoderDecoderCache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ if output_attentions or layer_head_mask is not None:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2AudioModel is using Qwen2AudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ past_key_value.is_updated[self.layer_idx] = True
+ past_key_value = past_key_value.cross_attention_cache
+ else:
+ past_key_value = past_key_value.self_attention_cache
+
+ # use key_value_states if cross attention
+ current_states = key_value_states if key_value_states is not None else hidden_states
+ if is_cross_attention and past_key_value and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value.key_cache[self.layer_idx]
+ value_states = past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
+
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+QWEN2AUDIO_ATTENTION_CLASSES = {
+ "eager": Qwen2AudioAttention,
+ "flash_attention_2": Qwen2AudioFlashAttention2,
+ "sdpa": Qwen2AudioSdpaAttention,
+}
+
+
+# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO
+class Qwen2AudioEncoderLayer(nn.Module):
+ def __init__(self, config: Qwen2AudioConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = QWEN2AUDIO_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ config=config,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_head_mask: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+QWEN2AUDIO_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Qwen2AudioConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Qwen2Audio Model outputting raw hidden-states without any specific head on top.",
+ QWEN2AUDIO_START_DOCSTRING,
+)
+class Qwen2AudioPreTrainedModel(PreTrainedModel):
+ config_class = Qwen2AudioConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2AudioAttention"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ # important: this ported version of Qwen2Audio isn't meant for training from scratch - only
+ # inference and fine-tuning - so the proper init weights code has been removed
+ std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_config.init_std
+
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+QWEN2AUDIOENCODER_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Qwen2AudioEncoderConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ """The audio model from Qwen2Audio without any head or projection on top.""",
+ QWEN2AUDIOENCODER_START_DOCSTRING,
+)
+# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder with Whisper->Qwen2Audio
+class AFWhisperEncoder(Qwen2AudioPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`Qwen2AudioEncoderLayer`].
+
+ Args:
+ config: Qwen2AudioEncoderConfig
+ """
+
+ # Ignore copy
+ config_class = Qwen2AudioEncoderConfig
+ main_input_name = "input_features"
+ _no_split_modules = ["Qwen2AudioEncoderLayer"]
+
+ def __init__(self, config: Qwen2AudioEncoderConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+
+ embed_dim = config.d_model
+ self.num_mel_bins = config.num_mel_bins
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_source_positions
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+ self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+
+ self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
+ self.embed_positions.requires_grad_(False)
+
+ self.layers = nn.ModuleList([Qwen2AudioEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(config.d_model)
+ # Ignore copy
+ self.avg_pooler = nn.AvgPool1d(2, stride=2)
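+        # Note: on top of the copied Whisper encoder, this stride-2 average pool further halves the temporal
+        # dimension after the transformer layers (e.g. 1500 conv frames -> 750 audio embeddings with the default
+        # config); `_get_feat_extract_output_lengths` below accounts for this extra downsampling.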
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.conv1
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.conv1 = value
+
+ def forward(
+ self,
+ input_features,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
+                Float values of mel features extracted from the raw speech waveform. The raw speech waveform can be
+                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
+                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
+                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+            attention_mask (`torch.Tensor`, *optional*):
+                Qwen2Audio does not support masking of the `input_features`; this argument is preserved for
+                compatibility, but it is not used. By default, silence in the input log mel spectrogram is ignored.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+ expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
+ if input_features.shape[-1] != expected_seq_length:
+ raise ValueError(
+ f"Qwen2Audio expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Ignore copy
+ input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
+
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
+ embed_pos = self.embed_positions.weight
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ assert head_mask.size()[0] == (
+ len(self.layers)
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ # Ignore copy
+ if to_drop:
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ # Ignore copy
+ hidden_states = hidden_states.permute(0, 2, 1)
+ hidden_states = self.avg_pooler(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 1)
+
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ # Ignore copy
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers and the output length of the audio encoder
+ """
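+        # Illustrative example, assuming the default 3000-frame mel input: the stride-2 conv2 gives
+        # (3000 - 1) // 2 + 1 = 1500 frames, and the stride-2 avg_pooler then gives (1500 - 2) // 2 + 1 = 750
+        # audio embeddings per clip.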
+ input_lengths = (input_lengths - 1) // 2 + 1
+ output_lengths = (input_lengths - 2) // 2 + 1
+ return input_lengths, output_lengths
+
+
+class Qwen2AudioMultiModalProjector(nn.Module):
+ def __init__(self, config: Qwen2AudioConfig):
+ super().__init__()
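+        # Projects each pooled audio embedding from the encoder width (`config.audio_config.d_model`) to the
+        # language model's hidden size (`config.text_config.hidden_size`).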
+ self.linear = nn.Linear(config.audio_config.d_model, config.text_config.hidden_size, bias=True)
+
+ def forward(self, audio_features):
+ hidden_states = self.linear(audio_features)
+ return hidden_states
+
+
+QWEN2AUDIO_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`):
+            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The QWEN2AUDIO model which consists of an audio backbone and a language model.""",
+ QWEN2AUDIO_START_DOCSTRING,
+)
+class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
+ def __init__(self, config: Qwen2AudioConfig):
+ super().__init__(config)
+ self.audio_tower = AutoModel.from_config(config.audio_config)
+
+ self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
+ self.post_init()
+
+ @property
+ def padding_side(self):
+ return self._padding_side
+
+ @padding_side.setter
+ def padding_side(self, padding_side: str):
+ if padding_side not in ["left", "right"]:
+ raise ValueError(f"{padding_side} is not `left` or `right`.")
+ self._padding_side = padding_side
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
+ def set_decoder(self, decoder):
+ self.language_model.set_decoder(decoder)
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
+ def tie_weights(self):
+ return self.language_model.tie_weights()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ # update vocab size
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
+ self.vocab_size = model_embeds.num_embeddings
+ return model_embeds
+
+ def _merge_input_ids_with_audio_features(
+ self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels
+ ):
+ """
+        Merge input_ids with audio features into the final embeddings
+
+ Args:
+ audio_features (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
+ All audio vectors of all audios in the batch
+ num_audio_tokens (`torch.LongTensor` of shape `(num_audios)`):
+ The length of audio embeddings of each audio as stacked in `audio_features`
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
+ Token embeddings before merging with audio embeddings
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Input ids of the tokens, possibly filled with audio tokens
+            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Mask to avoid performing attention on padding token indices.
+            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels, recalculated during the merge to support training (if provided)
+ Returns:
+ final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids
+
+        Explanation:
+            each audio has variable-length embeddings, with the length specified by num_audio_tokens
+            audio_features is the concatenation of all audio embedding vectors
+            task: replace each <|AUDIO|> with the correct number of audio embeddings
+ Example:
+ X (5 tokens), Y (3 tokens), Z (8 tokens)
+ X, Y are in the same sequence (in-context learning)
+ if right padding
+ input_ids: [
+ a b c d e f X g h i j k Y l m
+ o p q r Z s t u v _ _ _ _ _ _
+ ]
+ input_ids should be: [
+ a b c d e f X X X X X g h i j k Y Y Y l m
+ o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
+ ]
+ labels should be: [
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
+ o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
+ ]
+ elif left padding
+ input_ids: [
+ a b c d e f X g h i j k Y l m
+ _ _ _ _ _ _ o p q r Z s t u v
+ ]
+ input_ids should be: [
+ a b c d e f X X X X X g h i j k Y Y Y l m
+ _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
+ ]
+ labels should be: [
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
+ _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
+ ]
+ Edge cases:
+            * If the text tokens are identical but the audio token counts differ, the padding side cannot be inferred
+ ```python
+ url1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+ audio1, _ = librosa.load(BytesIO(urlopen(url1).read()), sr=processor.feature_extractor.sampling_rate)
+ url2 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"
+ audio2, _ = librosa.load(BytesIO(urlopen(url2).read()), sr=processor.feature_extractor.sampling_rate)
+ prompts = [
+ "[INST] <|AUDIO|>\nWhat is that in this audio? [/INST]",
+ "[INST] <|AUDIO|>\nWhat is that in this audio? [/INST]",
+ ]
+ inputs = processor(text=prompts, audios=[audio1, audio2], return_tensors='pt', padding=True).to("cuda")
+ audio1 has 101 tokens, while audio2 has 72 tokens
+ ```
+
+ input_ids: [
+ a b c d X g h
+ i j Y k l m n
+ ]
+            where X is 3 tokens while Y is 5, this means after the merge
+ if left-padding (batched generation)
+ input_ids should be: [
+ _ _ a b c d X X X g h
+ i j Y Y Y Y Y k l m n
+ ]
+ elif (right padding) (training)
+ input_ids should be: [
+ a b c d X X X g h _ _
+ i j Y Y Y Y Y k l m n
+ ]
+ """
+ num_audios, max_audio_tokens, embed_dim = audio_features.shape
+ audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
+ num_audio_tokens.device
+ ) < num_audio_tokens.unsqueeze(1)
+ masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)
+ batch_size, sequence_length = input_ids.shape
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
+
+ left_padding = True
+ if batch_size > 1:
+ if _left_padding and not _right_padding:
+ left_padding = True
+ elif not _left_padding and _right_padding:
+ left_padding = False
+            elif not _left_padding and not _right_padding:
+                # both sides are unpadded, so we cannot tell; fall back to the configured padding side
+                left_padding = self.padding_side == "left"
+            else:
+                # invalid attention_mask
+                raise ValueError(f"both sides of attention_mask contain zeros, which is invalid: {attention_mask}")
+
+ # 1. Create a mask to know where special audio tokens are
+ special_audio_token_mask = input_ids == self.config.audio_token_index
+ num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)
+
+ # In case the Audio model or the Language model has been offloaded to CPU, we need to manually
+ # set the corresponding tensors into their correct target device.
+ target_device = inputs_embeds.device
+ attention_mask = attention_mask.to(target_device)
+ input_ids = input_ids.to(target_device)
+ num_audio_tokens = num_audio_tokens.to(target_device)
+ batch_indices, non_audio_indices = torch.where(
+ (input_ids != self.config.audio_token_index) & (attention_mask == 1)
+ )
+
+ # 2. Compute the positions where text should be written
+ # Calculate new positions for text tokens in merged audio-text sequence.
+ # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens.
+ # `torch.cumsum` computes how each audio token shifts subsequent text token positions.
+ token_placeholder_num = torch.zeros_like(input_ids)
+ token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1
+ token_placeholder_num = token_placeholder_num + 1
+ new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
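+        # Illustrative example: for input_ids [a, b, <audio>, c] where the audio expands to 3 embeddings,
+        # token_placeholder_num is [1, 1, 3, 1] and new_token_positions = cumsum - 1 = [0, 1, 4, 5]; the text
+        # tokens land at positions 0, 1 and 5, and the audio embeddings later fill the remaining slots 2..4.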
+ max_token_num = token_placeholder_num.sum(-1).max()
+ nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_audio_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
+ batch_indices, non_audio_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_audio_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size, max_token_num, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ final_attention_mask = torch.zeros(
+ batch_size, max_token_num, dtype=attention_mask.dtype, device=inputs_embeds.device
+ )
+ final_input_ids = torch.full(
+ (batch_size, max_token_num), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
+ )
+
+        # 4. Fill the embeddings based on the mask. If we have ["hey", "<audio>", "how", "are"] and the audio expands
+        # to 576 embeddings, we index copy on [0, 577, 578, 579] for the text and [1:577] for the audio features
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
+ final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
+ final_labels = None
+ if labels is not None:
+ labels = labels.to(target_device)
+ final_labels = torch.full_like(final_attention_mask, self.config.ignore_index).to(torch.long)
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_audio_indices]
+
+ # 5. Fill the embeddings corresponding to the audios. Anything that is still zeros needs filling
+ audio_to_overwrite = torch.full(
+ (batch_size, max_token_num), True, dtype=torch.bool, device=inputs_embeds.device
+ )
+ audio_to_overwrite[batch_indices, text_to_overwrite] = False
+ seq_indices = torch.arange(max_token_num).unsqueeze(0).to(target_device)
+ seq_indices = seq_indices.expand(batch_size, max_token_num)
+
+ if left_padding:
+ # exclude padding on the left
+ max_token_num = max_token_num.to(target_device)
+ val = (max_token_num - seq_indices) <= (
+ token_placeholder_num.sum(-1) - (attention_mask == 0).long().sum(-1)
+ )[:, None]
+ else:
+ # exclude padding on the right
+ val = seq_indices < (token_placeholder_num.sum(-1) - (attention_mask == 0).long().sum(-1))[:, None]
+
+ audio_to_overwrite &= val
+
+ if audio_to_overwrite.sum() != num_audio_tokens.sum():
+ raise ValueError(
+                f"The inputs provided to the model are wrong. The number of audio tokens is {num_special_audio_tokens} while"
+                f" the number of audios given to the model is {num_audios}. This prevents correct indexing and breaks batch generation."
+ )
+
+ final_embedding[audio_to_overwrite] = (
+ masked_audio_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ )
+ final_attention_mask |= audio_to_overwrite
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
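+        # Padded slots receive a dummy position id of 1; they are excluded from attention by `final_attention_mask`.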
+
+ return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids
+
+ @add_start_docstrings_to_model_forward(QWEN2AUDIO_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Qwen2AudioCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ input_features: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ feature_attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Qwen2AudioCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from io import BytesIO
+ >>> from urllib.request import urlopen
+ >>> import librosa
+ >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+
+ >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
+
+ >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+ >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+        >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+
+ >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Generate the caption in English: Glass is breaking."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ target_device = self.audio_tower.device
+
+ if input_features is not None:
+ input_features = input_features.to(target_device)
+ feature_attention_mask = feature_attention_mask.to(target_device)
+
+ if inputs_embeds is None:
+ # 1. Extract the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and audios
+ if input_features is not None and input_ids.shape[1] != 1:
+ audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
+ feature_attention_mask.sum(-1)
+ )
+ batch_size, _, max_mel_seq_len = input_features.shape
+ max_seq_len = (max_mel_seq_len - 2) // 2 + 1
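+                # Length after the encoder's stride-2 convolution; the padding mask below is built at this
+                # resolution (the trailing average pool halves it again afterwards).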
+ # Create a sequence tensor of shape (batch_size, max_seq_len)
+ seq_range = (
+ torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device)
+ .unsqueeze(0)
+ .expand(batch_size, max_seq_len)
+ )
+ lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
+ # Create mask
+ padding_mask = seq_range >= lengths_expand
+
+ audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
+ batch_size, 1, max_seq_len, max_seq_len
+ )
+ audio_attention_mask = audio_attention_mask_.to(
+ dtype=self.audio_tower.conv1.weight.dtype, device=self.audio_tower.conv1.weight.device
+ )
+ audio_attention_mask[audio_attention_mask_] = float("-inf")
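+                # Convert the boolean padding mask into an additive float mask: padded positions become -inf so
+                # they receive zero weight after the attention softmax.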
+
+ audio_outputs = self.audio_tower(input_features, attention_mask=audio_attention_mask)
+ selected_audio_feature = audio_outputs.last_hidden_state
+ audio_features = self.multi_modal_projector(selected_audio_feature)
+
+ inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
+ audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
+ )
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = outputs[0]
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ if attention_mask is not None:
+ shift_attention_mask = attention_mask[..., 1:]
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Qwen2AudioCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ attention_mask=attention_mask,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ input_features=None,
+ attention_mask=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom processing (note: might not be needed, but there are no generation tests running atm)
+
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+
+ # Here, we get the attention_mask, which was previously stored in the state after _merge_input_ids_with_audio_features.
+ if input_features is not None and kwargs.get("attention_mask") is not None:
+ attention_mask = kwargs["attention_mask"]
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ elif self.config.audio_token_index in input_ids:
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+ # older attention values, as their corresponding values are not part of the input.
+ if cache_length < past_length and attention_mask is not None:
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ feature_attention_mask = kwargs.get("feature_attention_mask", None)
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "input_features": input_features,
+ "feature_attention_mask": feature_attention_mask,
+ }
+ )
+ return model_inputs
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ num_new_tokens: int = 1,
+ ) -> Dict[str, Any]:
+ # update past_key_values keeping its naming used in model code
+ cache_name, cache = self._extract_past_from_model_output(outputs)
+ model_kwargs[cache_name] = cache
+ if getattr(outputs, "state", None) is not None:
+ model_kwargs["state"] = outputs.state
+
+ # update attention_mask
+ if getattr(outputs, "attention_mask", None) is not None:
+ model_kwargs["attention_mask"] = outputs.attention_mask
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+ if not is_encoder_decoder:
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+ else:
+ # update decoder attention mask
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ model_kwargs["decoder_attention_mask"] = torch.cat(
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+ dim=-1,
+ )
+
+ if model_kwargs.get("use_cache", True):
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+ else:
+ past_positions = model_kwargs.pop("cache_position")
+ new_positions = torch.arange(
+ past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
+ ).to(past_positions.device)
+ model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
+ return model_kwargs
+
+ def _reorder_cache(self, *args, **kwargs):
+ return self.language_model._reorder_cache(*args, **kwargs)
\ No newline at end of file
diff --git a/llava/model/multimodal_encoder/siglip/__init__.py b/llava/model/multimodal_encoder/siglip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98031768cd2ad8f1bc9212d321afea76ac649b5c
--- /dev/null
+++ b/llava/model/multimodal_encoder/siglip/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .modeling_siglip import SiglipVisionModel
diff --git a/llava/model/multimodal_encoder/siglip/modeling_siglip.py b/llava/model/multimodal_encoder/siglip/modeling_siglip.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5b5c668e0e0aa06f5667e3d2828e12e99e45783
--- /dev/null
+++ b/llava/model/multimodal_encoder/siglip/modeling_siglip.py
@@ -0,0 +1,1677 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Siglip model."""
+
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn.init import _calculate_fan_in_and_fan_out
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+from transformers.utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "SiglipConfig"
+_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
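+    # Illustrative example: attention_mask = [[1, 1, 0], [1, 1, 1]] gives seqlens_in_batch = [2, 3],
+    # max_seqlen_in_batch = 3 and cu_seqlens = [0, 2, 5], the layout expected by the flash-attn varlen kernels.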
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+) -> torch.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+    best when :math:`a \\leq \\text{mean} \\leq b`.
+
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+ and the result is subsequently scaled and shifted by the mean and std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ with torch.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
+class SiglipVisionModelOutput(ModelOutput):
+ """
+    Base class for vision model's outputs that also contains image embeddings obtained by pooling the last hidden states.
+
+ Args:
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
+class SiglipTextModelOutput(ModelOutput):
+ """
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ text_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
+class SiglipOutput(ModelOutput):
+ """
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+            similarity scores.
+        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+            similarity scores.
+        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
+        text_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`SiglipTextModel`].
+        vision_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`SiglipVisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: torch.FloatTensor = None
+ logits_per_text: torch.FloatTensor = None
+ text_embeds: torch.FloatTensor = None
+ image_embeds: torch.FloatTensor = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+class SiglipVisionEmbeddings(nn.Module):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+        This method is adapted for SigLIP (which, unlike other ViTs, has no class embedding) and allows the model
+        to interpolate the pre-trained position encodings so that it can be used on higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+ position_embeddings = self.position_embedding.weight.unsqueeze(0)
+ num_patches = embeddings.shape[1]
+ num_positions = position_embeddings.shape[1]
+ if num_patches == num_positions and height == width:
+ return position_embeddings
+
+ dim = embeddings.shape[-1]
+ height = height // self.patch_size
+ width = width // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ height, width = height + 0.1, width + 0.1
+
+ patch_pos_embed = position_embeddings.reshape(
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+ _, _, height, width = pixel_values.shape
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
+class SiglipTextEmbeddings(nn.Module):
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class SiglipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class SiglipFlashAttention2(SiglipAttention):
+ """
+    SiglipAttention flash attention module. This module inherits from `SiglipAttention`, as the weights of the module
+    stay untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ is_causal = False
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+        # Flash Attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+        # therefore we just need to keep that layout (the transposes below cancel out, see the TODO).
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.dropout if self.training else 0.0
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32.
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32, this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+ )
+
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpads the input, then computes the attention scores, and finally pads the final attention scores back.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
+ """
+
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
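+        # indices_k: flat indices of the non-padded tokens, cu_seqlens_k: cumulative sequence lengths per
+        # batch entry (as expected by flash_attn_varlen_func), max_seqlen_in_batch_k: longest unpadded sequence.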
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+class SiglipSdpaAttention(SiglipAttention):
+ """
+    Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `SiglipAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
+    the SDPA API.
+ """
+
+ is_causal = False
+
+ # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if self.is_causal and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None
+
+
+SIGLIP_ATTENTION_CLASSES = {
+ "eager": SiglipAttention,
+ "flash_attention_2": SiglipFlashAttention2,
+ "sdpa": SiglipSdpaAttention,
+}
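+# The concrete attention class is selected per layer in `SiglipEncoderLayer` below, based on
+# `config._attn_implementation` (e.g. set via `attn_implementation="flash_attention_2"` in `from_pretrained`).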
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
+class SiglipMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class SiglipEncoderLayer(nn.Module):
+ def __init__(self, config: SiglipConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ # Ignore copy
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class SiglipPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = SiglipConfig
+ base_model_prefix = "siglip"
+ supports_gradient_checkpointing = True
+ _no_split_modules = [
+ "SiglipTextEmbeddings",
+ "SiglipEncoderLayer",
+ "SiglipVisionEmbeddings",
+ "SiglipMultiheadAttentionPoolingHead",
+ ]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, SiglipVisionEmbeddings):
+ width = (
+ self.config.vision_config.hidden_size
+ if isinstance(self.config, SiglipConfig)
+ else self.config.hidden_size
+ )
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+ elif isinstance(module, nn.Embedding):
+ default_flax_embed_init(module.weight)
+ elif isinstance(module, SiglipAttention):
+ nn.init.xavier_uniform_(module.q_proj.weight)
+ nn.init.xavier_uniform_(module.k_proj.weight)
+ nn.init.xavier_uniform_(module.v_proj.weight)
+ nn.init.xavier_uniform_(module.out_proj.weight)
+ nn.init.zeros_(module.q_proj.bias)
+ nn.init.zeros_(module.k_proj.bias)
+ nn.init.zeros_(module.v_proj.bias)
+ nn.init.zeros_(module.out_proj.bias)
+ elif isinstance(module, SiglipMLP):
+ nn.init.xavier_uniform_(module.fc1.weight)
+ nn.init.xavier_uniform_(module.fc2.weight)
+ nn.init.normal_(module.fc1.bias, std=1e-6)
+ nn.init.normal_(module.fc2.bias, std=1e-6)
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
+ nn.init.xavier_uniform_(module.probe.data)
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
+ nn.init.zeros_(module.attention.in_proj_bias.data)
+ elif isinstance(module, SiglipModel):
+ logit_scale_init = torch.log(torch.tensor(1.0))
+ module.logit_scale.data.fill_(logit_scale_init)
+ module.logit_bias.data.zero_()
+ elif isinstance(module, SiglipForImageClassification):
+ nn.init.normal_(
+ module.classifier.weight,
+ std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
+ )
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+SIGLIP_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+SIGLIP_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+SIGLIP_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
+class SiglipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SiglipEncoderLayer`].
+
+ Args:
+ config: SiglipConfig
+ """
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
+
+
+class SiglipTextTransformer(nn.Module):
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = SiglipTextEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ self.head = nn.Linear(embed_dim, embed_dim)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
+ # expand attention_mask
+ if attention_mask is not None and not self._use_flash_attention_2:
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
+ pooled_output = last_hidden_state[:, -1, :]
+ pooled_output = self.head(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """The text model from SigLIP without any head or projection on top.""",
+ SIGLIP_START_DOCSTRING,
+)
+class SiglipTextModel(SiglipPreTrainedModel):
+ config_class = SiglipTextConfig
+
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__(config)
+ self.text_model = SiglipTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, SiglipTextModel
+
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class SiglipVisionTransformer(nn.Module):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = SiglipVisionEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
+ if self.use_head:
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
+
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooler_output = self.head(last_hidden_state) if self.use_head else None
+ if not return_dict:
+ return (last_hidden_state, pooler_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooler_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class SiglipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(config)
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
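+        # The learned probe acts as a single query that cross-attends over all patch tokens,
+        # producing one pooled vector per image (returned as hidden_state[:, 0] below).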
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+@add_start_docstrings(
+ """The vision model from SigLIP without any head or projection on top.""",
+ SIGLIP_START_DOCSTRING,
+)
+class SiglipVisionModel(SiglipPreTrainedModel):
+ config_class = SiglipVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = SiglipVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, SiglipVisionModel
+
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled features
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+
+@add_start_docstrings(SIGLIP_START_DOCSTRING)
+class SiglipModel(SiglipPreTrainedModel):
+ config_class = SiglipConfig
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, SiglipTextConfig):
+ raise ValueError(
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, SiglipVisionConfig):
+ raise ValueError(
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ # First, initialize the text and vision models with proper attention implementation
+ text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
+ vision_model = SiglipVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
+
+ # Second, get the text and vision submodules (for backward compatibility)
+ self.text_model = text_model.text_model
+ self.vision_model = vision_model.vision_model
+
+ self.logit_scale = nn.Parameter(torch.randn(1))
+ self.logit_bias = nn.Parameter(torch.randn(1))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+ >>> with torch.no_grad():
+ ... text_features = model.get_text_features(**inputs)
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = text_outputs[1]
+
+ return pooled_output
+
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ pooled_output = vision_outputs[1]
+
+ return pooled_output
+
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> Union[Tuple, SiglipOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+ >>> # important: we pass `padding=max_length` since the model was trained with this
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> logits_per_image = outputs.logits_per_image
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
+ 31.9% that image 0 is 'a photo of 2 cats'
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ image_embeds = vision_outputs[1]
+ text_embeds = text_outputs[1]
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
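+        # logit_scale and logit_bias are learned scalar parameters (see __init__); unlike CLIP,
+        # SigLIP also learns an additive bias on the scaled pairwise similarities.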
+ logits_per_text = (
+ torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
+ + self.logit_bias
+ )
+ logits_per_image = logits_per_text.t()
+
+ loss = None
+ if return_loss:
+ # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
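+            # Pairwise sigmoid loss: matching image-text pairs (the diagonal) get label +1, all other
+            # pairs get label -1, and each pair contributes -log(sigmoid(label * logit)) to the loss.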
+ eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+ nll = -torch.sum(loglik, dim=-1)
+ loss = nll.mean()
+
+ if not return_dict:
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+ return ((loss,) + output) if loss is not None else output
+
+ return SiglipOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+@add_start_docstrings(
+ """
+ SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
+ the patch tokens) e.g. for ImageNet.
+ """,
+ SIGLIP_START_DOCSTRING,
+)
+class SiglipForImageClassification(SiglipPreTrainedModel):
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: SiglipConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+
+ # Create the vision model with proper attention
+ # and take only vision_model submodule (for backward compatibility)
+ vision_model = SiglipVisionModel._from_config(
+ config.vision_config, attn_implementation=config._attn_implementation
+ )
+ self.vision_model = vision_model.vision_model
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, SiglipForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: we are loading a `SiglipModel` from the hub here,
+ >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
+ >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the two classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ Predicted class: LABEL_1
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.vision_model(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ sequence_output = outputs[0]
+
+ # average pool the patch tokens
+ sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
+ # apply classifier
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/llava/model/multimodal_encoder/siglip_encoder.py b/llava/model/multimodal_encoder/siglip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..74217dd447fcf9a79245aea5e136ca28ded2ae22
--- /dev/null
+++ b/llava/model/multimodal_encoder/siglip_encoder.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from transformers import PretrainedConfig, SiglipImageProcessor
+
+from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerDynamicS2, VisionTowerS2
+
+from .siglip import SiglipVisionModel
+
+
+class SiglipVisionTower(VisionTower):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
+ super().__init__(model_name_or_path, config)
+        # TODO(ligengl): why does passing config here lead to errors?
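+        # Note: config.model_dtype is assumed to be a string such as "torch.float16";
+        # eval() turns it into the corresponding torch dtype object.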
+ self.vision_tower = SiglipVisionModel.from_pretrained(
+ model_name_or_path,
+ attn_implementation="flash_attention_2",
+ torch_dtype=eval(config.model_dtype),
+ )
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
+ self.is_loaded = True
+
+
+class SiglipVisionTowerS2(VisionTowerS2):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
+ super().__init__(model_name_or_path, config)
+ self.vision_tower = SiglipVisionModel.from_pretrained(
+ model_name_or_path,
+ attn_implementation="flash_attention_2",
+ torch_dtype=eval(config.model_dtype),
+ )
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
+ self.is_loaded = True
+
+
+class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
+ super().__init__(model_name_or_path, config)
+ self.vision_tower = SiglipVisionModel.from_pretrained(
+ model_name_or_path,
+ attn_implementation="flash_attention_2",
+ torch_dtype=eval(config.model_dtype),
+ )
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
+ self.is_loaded = True
diff --git a/llava/model/multimodal_encoder/sound_encoder.py b/llava/model/multimodal_encoder/sound_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..742efd89b6d036636fc7a95e564b7016cd8b2874
--- /dev/null
+++ b/llava/model/multimodal_encoder/sound_encoder.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class SoundTower(nn.Module):
+ def __init__(self, sound_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.sound_tower_name = sound_tower
+ self.cfg_only = None
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers and the output length of the audio encoder
+ """
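+        # e.g. assuming a Whisper-style mel front end with a 30 s window of 3000 mel frames:
+        # (3000 - 1) // 2 + 1 = 1500 conv frames, then (1500 - 2) // 2 + 1 = 750 encoder frames.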
+ input_lengths = (input_lengths - 1) // 2 + 1
+ output_lengths = (input_lengths - 2) // 2 + 1
+ return input_lengths, output_lengths
+
+ def forward(self, sounds, mask=None):
+
+ if type(sounds) is list:
+ sound_features = []
+ for sound in sounds:
+ # Calculate attention mask
+ audio_feat_lengths, audio_output_lengths = self._get_feat_extract_output_lengths(mask.sum(-1))
+ # for cases where only one window is there for the audio_clip
+ batch_size, _, max_mel_seq_len = sound.shape
+ max_seq_len = (max_mel_seq_len - 2) // 2 + 1
+ seq_range = (
+ torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device)
+ .unsqueeze(0)
+ .expand(batch_size, max_seq_len)
+ )
+ lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
+ padding_mask = seq_range >= lengths_expand
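+                # padding_mask is True at padded frame positions; below it is broadcast to a square
+                # (seq, seq) mask and those positions are filled with -inf (an additive attention mask).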
+ audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
+ batch_size, 1, max_seq_len, max_seq_len
+ )
+ audio_attention_mask = audio_attention_mask_.to(
+ dtype=self.sound_tower.conv1.weight.dtype, device=self.sound_tower.conv1.weight.device
+ )
+ audio_attention_mask[audio_attention_mask_] = float("-inf")
+ # Calculate features
+ sound_feature = self.sound_tower(sound, attention_mask=audio_attention_mask)
+ sound_feature = sound_feature.to(sound.dtype)
+ sound_feature = sound_feature.last_hidden_state
+ sound_features.append(sound_feature)
+ else:
+ # Calculate attention mask
+ if len(sounds.shape) == 5:
+ sounds = sounds.squeeze(0).squeeze(1)
+ mask = mask.squeeze(0)
+ audio_feat_lengths, audio_output_lengths = self._get_feat_extract_output_lengths(mask.sum(-1))
+ # for cases where only one window is there for the audio_clip
+ batch_size, _, max_mel_seq_len = sounds.shape
+ max_seq_len = (max_mel_seq_len - 2) // 2 + 1
+ seq_range = (
+ torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device)
+ .unsqueeze(0)
+ .expand(batch_size, max_seq_len)
+ )
+            lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
+ padding_mask = seq_range >= lengths_expand
+ audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
+ batch_size, 1, max_seq_len, max_seq_len
+ )
+ audio_attention_mask = audio_attention_mask_.to(
+ dtype=self.sound_tower.conv1.weight.dtype, device=self.sound_tower.conv1.weight.device
+ )
+ audio_attention_mask[audio_attention_mask_] = float("-inf")
+ # Calculate features
+ sound_features = self.sound_tower(sounds, attention_mask=audio_attention_mask)
+ sound_features = sound_features.last_hidden_state
+ sound_features = sound_features.to(sounds.dtype)
+
+ return sound_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.sound_tower.dtype
+
+ @property
+ def config(self):
+ if self.is_loaded:
+ return self.sound_tower.config
+ else:
+ return self.cfg_only
+
+ @property
+ def device(self):
+ return self.sound_tower.device
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size
+
+
diff --git a/llava/model/multimodal_encoder/speech_encoder.py b/llava/model/multimodal_encoder/speech_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ca5d8561250ab652dc6853c33bdf138c754b46
--- /dev/null
+++ b/llava/model/multimodal_encoder/speech_encoder.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class SpeechTower(nn.Module):
+ def __init__(self, speech_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.speech_tower_name = speech_tower
+ self.cfg_only = None
+
+ def forward(self, speeches):
+ if type(speeches) is list:
+ speech_features = []
+ for speech in speeches:
+ speech_feature = self.speech_tower.encoder(speech)
+ speech_feature = speech_feature.last_hidden_state
+ speech_feature = speech_feature.to(speech.dtype)
+ speech_features.append(speech_feature)
+ else:
+ speech_features = self.speech_tower.encoder(speeches)
+ speech_features = speech_features.last_hidden_state
+ speech_features = speech_features.to(speeches.dtype)
+
+ return speech_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.speech_tower.dtype
+
+ @property
+ def device(self):
+ return self.speech_tower.device
+
+ @property
+ def config(self):
+ if self.is_loaded:
+ return self.speech_tower.config
+ else:
+ return self.cfg_only
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size
+
+
diff --git a/llava/model/multimodal_encoder/vision_encoder.py b/llava/model/multimodal_encoder/vision_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90869c439b1764da715cbac9d22bfdf14e24a77
--- /dev/null
+++ b/llava/model/multimodal_encoder/vision_encoder.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from accelerate.hooks import add_hook_to_module
+from einops import rearrange
+from s2wrapper import forward as multiscale_forward
+from transformers import AutoConfig, PreTrainedModel
+from transformers.image_processing_utils import BaseImageProcessor
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+
+
+class VisionTower(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.vision_tower_name = vision_tower
+ self.select_layer = getattr(args, "mm_vision_select_layer", -2)
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
+
+ self.cfg_only = None
+
+ def feature_select(self, image_forward_outs):
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+ if self.select_feature == "patch":
+ image_features = image_features[:, 1:]
+ elif self.select_feature == "cls_patch":
+ image_features = image_features
+ else:
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
+ return image_features
+
+ def _maybe_resize_pos_embeds(
+ self,
+ model: PreTrainedModel,
+ image_processor: BaseImageProcessor,
+ resolution: int = -1,
+ interpolate_mode: str = "linear",
+ ):
+ if resolution in [model.config.image_size, -1]:
+ return
+ print(
+ f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
+ )
+ embeddings = model.vision_model.embeddings
+ patch_size = embeddings.patch_size
+ num_new_tokens = int((resolution // patch_size) ** 2)
+
+ old_embeddings = embeddings.position_embedding
+ match interpolate_mode:
+ case "linear":
+ ## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
+ ## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
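+                ## Example: resizing from a 27x27 grid (729 positions) to a 32x32 grid (1024 positions),
+                ## new index 100 maps to pid = 100 / 1023 * 728 ~= 71.16, so its embedding becomes
+                ## roughly 0.16 * embeds[72] + 0.84 * embeds[71] (illustrative numbers only).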
+ import torch
+ import torch.nn as nn
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+ else:
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+ new_embeddings = nn.Embedding(
+ num_new_tokens,
+ old_embedding_dim,
+ dtype=old_embeddings.weight.dtype,
+ device=old_embeddings.weight.device,
+ )
+ mapped_indices = (
+ torch.arange(num_new_tokens).to(old_embeddings.weight.device)
+ / (num_new_tokens - 1)
+ * (old_num_tokens - 1)
+ )
+ floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
+ ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
+ if is_deepspeed_zero3_enabled():
+ params = [old_embeddings.weight, new_embeddings.weight]
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
+ ceil_indices, :
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
+ else:
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
+ ceil_indices, :
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
+ new_embeddings.weight.data = interpolated_embeds
+ case _:
+ raise NotImplementedError
+
+ if hasattr(old_embeddings, "_hf_hook"):
+ hook = old_embeddings._hf_hook
+ add_hook_to_module(new_embeddings, hook)
+ new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
+ ## update vision encoder's configurations
+ model.config.image_size = resolution
+ if hasattr(image_processor, "crop_size"):
+ # CLIP vision tower
+ image_processor.crop_size = resolution
+ else:
+ # SIGLIP vision tower
+ assert hasattr(image_processor, "size")
+ image_processor.size = {"height": resolution, "width": resolution}
+ ## TODO define a '_reinitialize' method for VisionTower
+ embeddings.position_embedding = new_embeddings
+ embeddings.image_size = resolution
+ embeddings.num_patches = embeddings.num_positions = num_new_tokens
+ embeddings.position_ids = (
+ torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
+ )
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.vision_tower(
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
+ output_hidden_states=True,
+ )
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.vision_tower(
+ images.to(device=self.device, dtype=self.dtype),
+ output_hidden_states=True,
+ )
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.vision_tower.dtype
+
+ @property
+ def device(self):
+ return self.vision_tower.device
+
+ @property
+ def config(self):
+ if self.is_loaded:
+ return self.vision_tower.config
+ else:
+ return self.cfg_only
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size
+
+ @property
+ def num_patches(self):
+ return (self.config.image_size // self.config.patch_size) ** 2
+
+
+class VisionTowerS2(VisionTower):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__(vision_tower, args, delay_load)
+
+ self.scales = list(map(int, args.s2_scales.split(",")))
+ self.scales.sort()
+ self.max_split_size = args.s2_max_split_size
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
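+        # e.g. s2_scales="336,672,1008": forward() runs the tower once per scale via s2wrapper's
+        # multiscale_forward and concatenates per-scale features, so hidden_size grows by len(self.scales).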
+
+ def forward_feature(self, images):
+ image_forward_outs = self.vision_tower(
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
+ )
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+ return image_features
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_feature = multiscale_forward(
+ self.forward_feature,
+ image.unsqueeze(0),
+ img_sizes=self.scales,
+ max_split_size=self.max_split_size,
+ resize_output_to_idx=self.resize_output_to_scale_idx,
+ )
+ image_features.append(image_feature)
+ else:
+ image_features = multiscale_forward(
+ self.forward_feature,
+ images,
+ img_sizes=self.scales,
+ max_split_size=self.max_split_size,
+ resize_output_to_idx=self.resize_output_to_scale_idx,
+ )
+
+ return image_features
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size * len(self.scales)
+
+
+class VisionTowerDynamicS2(VisionTower):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__(vision_tower, args, delay_load)
+
+ self.scales = list(map(int, args.s2_scales.split(",")))
+ self.scales.sort()
+ self.max_split_size = args.s2_max_split_size
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
+
+ def forward_feature(self, images):
+ image_forward_outs = self.vision_tower(
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
+ )
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+ return image_features
+
+ def forward(self, images):
+ assert type(images) is not list
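+        # Unlike VisionTowerS2, no multiscale_forward here: the multi-scale tiling and merging are
+        # presumably handled outside this module, so only a batched tensor is accepted.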
+
+ image_features = self.forward_feature(images)
+
+ return image_features
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size * len(self.scales)
diff --git a/llava/model/multimodal_encoder/visualize_features.py b/llava/model/multimodal_encoder/visualize_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa17b07dafe9ec2c874b33c3ef00d54c54a1bbc
--- /dev/null
+++ b/llava/model/multimodal_encoder/visualize_features.py
@@ -0,0 +1,357 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+import argparse
+import gc
+import math
+import os
+import random
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Tuple
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from datasets import load_dataset, load_dataset_builder
+from datasets.distributed import split_dataset_by_node
+from einops import rearrange
+from PIL import Image
+from torch import nn
+from torch.utils.data import DataLoader
+from torchvision.utils import make_grid
+from tqdm import tqdm
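+from torch_kmeans import CosineSimilarity, KMeans  # assumed source of the KMeans/CosineSimilarity used in get_cluster_map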
+
+# from common import rank_print, load_model, get_standard_transform, collate
+#
+# try:
+# import wandb
+# except ImportError:
+# wandb = None
+
+
+LAYER_STATS = dict()
+
+
+@torch.inference_mode()
+def main(rank: int = 0, world_size: int = 1):
+ """
+ Computes the RankMe (http://arxiv.org/abs/2210.02885) and LiDAR (http://arxiv.org/abs/2312.04000)
+ estimates of the rank of the produced embeddings. While RADIO doesn't train in a multi-view setting
+ which is an assumption of LiDAR, the metric does integrate an important concept of the invariance of the
+ summary features to different view/augmentations of the same image.
+ """
+
+ local_rank = rank % torch.cuda.device_count()
+ torch.cuda.set_device(local_rank)
+ cv2.setNumThreads(1)
+
+ device = torch.device("cuda", local_rank)
+ parser = argparse.ArgumentParser(description="Compute SSL embedding rank estimates")
+ parser.add_argument("-v", "--model-version", default="radio_v2", help="Which radio model to load.")
+ parser.add_argument("-d", "--dataset", default="imagenet-1k", help="The name of the dataset to classify")
+ parser.add_argument("--split", default="validation", help="The dataset split to use.")
+ parser.add_argument("-n", default=10, type=int, help="The number of samples to load")
+ parser.add_argument(
+ "-r",
+ "--resolution",
+ nargs="+",
+ type=int,
+ default=None,
+ help="The input image resolution."
+ " If one value is specified, the shortest dimension is resized to this."
+ " If two, the image is center cropped."
+ " If not specified, center cropped 378px is used."
+ " Default: The RADIO model's preferred resolution.",
+ )
+ parser.add_argument(
+ "--resize-multiple",
+ type=int,
+ default=None,
+ help="Resize images with dimensions a multiple of this value."
+ " This should be equal to the patch size of a ViT (e.g. RADIOv1)",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=16,
+ help="The batch size. If the input is variable sized, then this argument becomes a maximum.",
+ )
+ parser.add_argument("--workers", default=8, type=int, help="Number of loader workers to use")
+ parser.add_argument(
+ "--vitdet-window-size", default=None, type=int, help="Enable ViTDet at the specific window size"
+ )
+ parser.add_argument("--output-dir", default="vis_denoise", type=str)
+ parser.add_argument("--adaptor-name", default=None, type=str, help="Generate features from a teacher adaptor")
+
+ args, _ = parser.parse_known_args()
+
+ torch.manual_seed(42 + rank)
+ np.random.seed(42 + rank)
+ random.seed(42 + rank)
+
+ rank_print("Loading model...")
+ model, preprocessor, info = load_model(
+ args.model_version, vitdet_window_size=args.vitdet_window_size, adaptor_name=args.adaptor_name
+ )
+ model.to(device=device).eval()
+ if isinstance(preprocessor, nn.Module):
+ preprocessor.to(device).eval()
+ rank_print("Done")
+
+ rank_print("Loading dataset...")
+ ds_builder = load_dataset_builder(args.dataset, trust_remote_code=True)
+
+ if args.resolution is None:
+ args.resolution = (model.preferred_resolution.height, model.preferred_resolution.width)
+
+ patch_size = model.patch_size
+
+ if args.resize_multiple is None:
+ args.resize_multiple = getattr(model, "min_resolution_step", model.patch_size)
+
+ transform = get_standard_transform(args.resolution, args.resize_multiple)
+ dataset = ds_builder.as_dataset(split=args.split)
+ dataset = dataset.to_iterable_dataset(num_shards=world_size * max(1, args.workers))
+ dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
+ dataset = dataset.map(
+ lambda ex: dict(image=transform(ex["image"]), label=torch.as_tensor(ex["label"], dtype=torch.int64))
+ )
+
+ loader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.workers,
+ collate_fn=collate,
+ pin_memory=args.workers > 0,
+ drop_last=False,
+ )
+ rank_print("Done")
+ rank_print(f"Description: {ds_builder.info.description}")
+
+ dirs = dict(
+ orig=os.path.join(args.output_dir, "orig"),
+ viz=os.path.join(args.output_dir, "viz"),
+ sbs=os.path.join(args.output_dir, "sbs"),
+ )
+
+ for d in dirs.values():
+ os.makedirs(d, exist_ok=True)
+
+ ctr = 0
+ for batches in loader:
+ if ctr >= args.n:
+ break
+
+ for images, _ in batches:
+ images = images.to(device=device, non_blocking=True)
+
+ all_feat = []
+ with torch.autocast(device.type, dtype=torch.bfloat16):
+ p_images = preprocessor(images)
+
+ output = model(p_images)
+ if args.adaptor_name:
+ all_feat = [
+ output["backbone"].features,
+ output[args.adaptor_name].features,
+ ]
+ else:
+ all_feat = [output[1]]
+
+ all_feat = torch.stack(all_feat, dim=1)
+
+ num_rows = images.shape[-2] // patch_size
+ num_cols = images.shape[-1] // patch_size
+
+ all_feat = rearrange(all_feat, "b m (h w) c -> b m h w c", h=num_rows, w=num_cols).float()
+
+ for i, feats in enumerate(all_feat):
+ colored = []
+ for features in feats:
+ color = get_pca_map(features, images.shape[-2:])
+ colored.append(color)
+
+ orig = cv2.cvtColor(images[i].permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2BGR)
+
+ cv2.imwrite(f'{dirs["orig"]}/vis_{ctr}.jpg', orig * 255)
+ cv2.imwrite(f'{dirs["viz"]}/vis_{ctr}.jpg', colored[-1] * 255)
+
+ op = np.concatenate([orig] + colored, axis=1) * 255
+
+ cv2.imwrite(f'{dirs["sbs"]}/vis_{ctr}.jpg', op)
+ ctr += 1
+
+
+def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
+ # features: (N, C)
+ # m: a hyperparam controlling how many std dev outside for outliers
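+    # Project features onto their top-3 PCA directions and derive robust per-channel color ranges by
+    # keeping only values within m median-absolute-deviations of the median.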
+ assert len(features.shape) == 2, "features should be (N, C)"
+ reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
+ colors = features @ reduction_mat
+ if remove_first_component:
+ colors_min = colors.min(dim=0).values
+ colors_max = colors.max(dim=0).values
+ tmp_colors = (colors - colors_min) / (colors_max - colors_min)
+ fg_mask = tmp_colors[..., 0] < 0.2
+ reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
+ colors = features @ reduction_mat
+ else:
+ fg_mask = torch.ones_like(colors[:, 0]).bool()
+ d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
+ mdev = torch.median(d, dim=0).values
+ s = d / mdev
+ try:
+ rins = colors[fg_mask][s[:, 0] < m, 0]
+ gins = colors[fg_mask][s[:, 1] < m, 1]
+ bins = colors[fg_mask][s[:, 2] < m, 2]
+ rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
+ rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
+    except Exception:
+ rins = colors
+ gins = colors
+ bins = colors
+ rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
+ rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
+
+ return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
+
+
+def get_pca_map(
+ feature_map: torch.Tensor,
+ img_size,
+ interpolation="bicubic",
+ return_pca_stats=False,
+ pca_stats=None,
+):
+ """
+ feature_map: (1, h, w, C) is the feature map of a single image.
+ """
+ feature_map = feature_map.float()
+ if feature_map.shape[0] != 1:
+ # make it (1, h, w, C)
+ feature_map = feature_map[None]
+ if pca_stats is None:
+ reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1]))
+ else:
+ reduct_mat, color_min, color_max = pca_stats
+ pca_color = feature_map @ reduct_mat
+ pca_color = (pca_color - color_min) / (color_max - color_min)
+ pca_color = F.interpolate(
+ pca_color.permute(0, 3, 1, 2),
+ size=img_size,
+ mode=interpolation,
+ ).permute(0, 2, 3, 1)
+ pca_color = pca_color.clamp(0, 1)
+ pca_color = pca_color.cpu().numpy().squeeze(0)
+ if return_pca_stats:
+ return pca_color, (reduct_mat, color_min, color_max)
+ return pca_color
+
+
+def get_scale_map(
+ scalar_map: torch.Tensor,
+ img_size,
+ interpolation="nearest",
+):
+ """
+ scalar_map: (1, h, w, C) is the feature map of a single image.
+ """
+ if scalar_map.shape[0] != 1:
+ scalar_map = scalar_map[None]
+ scalar_map = (scalar_map - scalar_map.min()) / (scalar_map.max() - scalar_map.min() + 1e-6)
+ scalar_map = F.interpolate(
+ scalar_map.permute(0, 3, 1, 2),
+ size=img_size,
+ mode=interpolation,
+ ).permute(0, 2, 3, 1)
+ # cmap = plt.get_cmap("viridis")
+ # scalar_map = cmap(scalar_map)[..., :3]
+ # make it 3 channels
+ scalar_map = torch.cat([scalar_map] * 3, dim=-1)
+ scalar_map = scalar_map.cpu().numpy().squeeze(0)
+ return scalar_map
+
+
+def get_similarity_map(features: torch.Tensor, img_size=(224, 224)):
+ """
+ compute the similarity map of the central patch to the rest of the image
+ """
+    assert len(features.shape) == 4, "features should be (1, H, W, C)"
+ H, W, C = features.shape[1:]
+ center_patch_feature = features[0, H // 2, W // 2, :]
+ center_patch_feature_normalized = center_patch_feature / center_patch_feature.norm()
+ center_patch_feature_normalized = center_patch_feature_normalized.unsqueeze(1)
+ # Reshape and normalize the entire feature tensor
+ features_flat = features.view(-1, C)
+ features_normalized = features_flat / features_flat.norm(dim=1, keepdim=True)
+
+ similarity_map_flat = features_normalized @ center_patch_feature_normalized
+ # Reshape the flat similarity map back to the spatial dimensions (H, W)
+ similarity_map = similarity_map_flat.view(H, W)
+
+ # Normalize the similarity map to be in the range [0, 1] for visualization
+ similarity_map = (similarity_map - similarity_map.min()) / (similarity_map.max() - similarity_map.min())
+ # we don't want the center patch to be the most similar
+ similarity_map[H // 2, W // 2] = -1.0
+ similarity_map = (
+ F.interpolate(
+ similarity_map.unsqueeze(0).unsqueeze(0),
+ size=img_size,
+ mode="bilinear",
+ )
+ .squeeze(0)
+ .squeeze(0)
+ )
+
+ similarity_map_np = similarity_map.cpu().numpy()
+ negative_mask = similarity_map_np < 0
+
+ colormap = plt.get_cmap("turbo")
+
+ # Apply the colormap directly to the normalized similarity map and multiply by 255 to get RGB values
+ similarity_map_rgb = colormap(similarity_map_np)[..., :3]
+ similarity_map_rgb[negative_mask] = [1.0, 0.0, 0.0]
+ return similarity_map_rgb
+
+
+def get_cluster_map(
+ feature_map: torch.Tensor,
+ img_size,
+ num_clusters=10,
+) -> torch.Tensor:
+ kmeans = KMeans(n_clusters=num_clusters, distance=CosineSimilarity, verbose=False)
+ if feature_map.shape[0] != 1:
+ # make it (1, h, w, C)
+ feature_map = feature_map[None]
+ labels = kmeans.fit_predict(feature_map.reshape(1, -1, feature_map.shape[-1])).float()
+ labels = (
+ F.interpolate(labels.reshape(1, *feature_map.shape[:-1]), size=img_size, mode="nearest").squeeze().cpu().numpy()
+ ).astype(int)
+ cmap = plt.get_cmap("rainbow", num_clusters)
+ cluster_map = cmap(labels)[..., :3]
+ return cluster_map.reshape(img_size[0], img_size[1], 3)
+
+
+if __name__ == "__main__":
+ rank = 0
+ world_size = 1
+
+ # if 'WORLD_SIZE' in os.environ:
+ # dist.init_process_group(backend='nccl')
+ # rank = dist.get_rank()
+ # world_size = dist.get_world_size()
+
+ main(rank, world_size)
diff --git a/llava/model/multimodal_encoder/whisper_encoder.py b/llava/model/multimodal_encoder/whisper_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb34b164f2a8ad3cb343f5b3fbb434c2b652dc8a
--- /dev/null
+++ b/llava/model/multimodal_encoder/whisper_encoder.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+import torch
+from transformers import WhisperProcessor, WhisperModel
+from transformers import PretrainedConfig
+
+from llava.model.multimodal_encoder.speech_encoder import SpeechTower
+
+class WhisperSpeechTower(SpeechTower):
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
+ super().__init__(model_name_or_path, config)
+ self.speech_tower = WhisperModel.from_pretrained(model_name_or_path)
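+        # The full WhisperModel is loaded, but only its encoder is used downstream (see SpeechTower.forward).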
+ self.is_loaded = True
diff --git a/llava/model/multimodal_projector/base_projector.py b/llava/model/multimodal_projector/base_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce279bddd0c6f47c206e9c83f65bd08c7ba7495a
--- /dev/null
+++ b/llava/model/multimodal_projector/base_projector.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+
+
+class IdentityMap(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": "identity"}
+
+
+class SimpleResBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(channels)
+
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
+
+ def forward(self, x):
+ x = self.pre_norm(x)
+ return x + self.proj(x)
+
+
+class DownSampleBlock(nn.Module):
+ def forward(self, x):
+ vit_embeds = x
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = self.flat_square(vit_embeds)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+ return vit_embeds
+
+ def flat_square(self, x):
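+        # Fold each 2x2 spatial neighbourhood into the channel dim: (n, w, h, c) -> (n, w/2, h/2, 4c),
+        # zero-padding odd spatial sizes first; this is what gives the projector its 2x downsample rate.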
+ n, w, h, c = x.size()
+ if w % 2 == 1:
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
+ n, w, h, c = x.size()
+ if h % 2 == 1:
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
+ n, w, h, c = x.size()
+ x = x.contiguous()
+ x = x.view(n, w, int(h / 2), int(c * 2))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+
+class DownSample2x2BlockFix(nn.Module):
+ def forward(self, x):
+ vit_embeds = x
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = flat_square_2x2(vit_embeds)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+ return vit_embeds
+
+
+def flat_square_2x2(x):
+ n, w, h, c = x.size()
+ if w % 2 == 1:
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
+ n, w, h, c = x.size()
+ x = x.contiguous()
+ if h % 2 == 1:
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
+ n, w, h, c = x.size()
+ x = x.view(n, w, int(h / 2), int(c * 2))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+
+class DownSample3x3BlockFix(nn.Module):
+ def forward(self, x):
+ vit_embeds = x
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = flat_square_3x3(vit_embeds)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+ return vit_embeds
+
+
+def flat_square_3x3(x):
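+    # Same idea as the 2x2 merge but with 3x3 neighbourhoods: (n, w, h, c) -> (n, w/3, h/3, 9c).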
+ n, w, h, c = x.size()
+ if w % 3 != 0:
+ x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
+ n, w, h, c = x.size()
+ x = x.contiguous()
+ if h % 3 != 0:
+ x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
+ n, w, h, c = x.size()
+ x = x.view(n, w, int(h / 3), int(c * 3))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+
+class MultimodalProjectorConfig(PretrainedConfig):
+ model_type = "v2l_projector"
+
+ def __init__(self, mm_projector_type: str = None, **kwargs):
+ super().__init__()
+ self.mm_projector_type = mm_projector_type
+
+
+class MultimodalProjector(PreTrainedModel):
+ config_class = MultimodalProjectorConfig
+
+ def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
+ super().__init__(mm_projector_cfg)
+ mm_projector_type = mm_projector_cfg.mm_projector_type
+ self.downsample_rate = 1
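+        # downsample_rate records the per-side spatial reduction of the chosen projector
+        # (2 for the 2x2 merges below, 3 for the 3x3 merge).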
+ if mm_projector_type == "identity":
+ self.layers = IdentityMap()
+ elif mm_projector_type == "linear":
+ self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
+ elif mm_projector_type == "mlp_downsample":
+ self.layers = nn.Sequential(
+ DownSampleBlock(),
+ nn.LayerNorm(config.mm_hidden_size * 4),
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+ self.downsample_rate = 2
+ elif mm_projector_type == "mlp_downsample_2x2_fix":
+ self.layers = nn.Sequential(
+ DownSample2x2BlockFix(),
+ nn.LayerNorm(config.mm_hidden_size * 4),
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+ self.downsample_rate = 2
+ elif mm_projector_type == "mlp_downsample_3x3_fix":
+ self.layers = nn.Sequential(
+ DownSample3x3BlockFix(),
+ nn.LayerNorm(config.mm_hidden_size * 9),
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size * 3),
+ nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+ self.downsample_rate = 3
+ elif mm_projector_type == "mlp_downsample_3x3_s2":
+ self.layers = nn.Sequential(
+ DownSample3x3BlockFix(),
+ nn.LayerNorm(config.mm_hidden_size * 9),
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size * 3),
+ nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size),
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size // 3),
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+ elif mm_projector_type == "mlp_downsample_3x3_s2_new":
+ self.layers = nn.Sequential(
+ DownSample3x3BlockFix(),
+ nn.LayerNorm(config.mm_hidden_size * 9),
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size * 4),
+ nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size * 2),
+ nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size),
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
+ nn.GELU(),
+ nn.LayerNorm(config.mm_hidden_size // 3),
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+ else:
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
+ if mlp_gelu_match:
+ mlp_depth = int(mlp_gelu_match.group(1))
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+ self.layers = nn.Sequential(*modules)
+ else:
+ raise ValueError(f"Unknown projector type: {mm_projector_type}")
+
+ def forward(self, x, *args, **kwargs):
+ return self.layers(x)
+
+
+AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
+AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
diff --git a/llava/model/multimodal_projector/builder.py b/llava/model/multimodal_projector/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..79792f0fe00ded81734a87ea247e8e0096667a16
--- /dev/null
+++ b/llava/model/multimodal_projector/builder.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+import os
+
+import torch
+from transformers import PretrainedConfig, PreTrainedModel
+
+from .base_projector import MultimodalProjector, MultimodalProjectorConfig
+from .speech_base_projector import SpeechMultimodalProjector, SpeechMultimodalProjectorConfig
+from .sound_base_projector import SoundMultimodalProjector, SoundMultimodalProjectorConfig
+
+
+def build_speech_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
+ if model_type_or_path is None:
+ return None
+
+ ## load from pretrained model
+ if config.resume_path:
+ assert os.path.exists(model_type_or_path), f"Resume speech mm projector path {model_type_or_path} does not exist!"
+ return SpeechMultimodalProjector.from_pretrained(model_type_or_path, config, torch_dtype=eval(config.model_dtype))
+ ## build from scratch
+ else:
+ print("WARNING: Building speech multimodal projector from scratch!")
+ speech_mm_projector_cfg = SpeechMultimodalProjectorConfig(model_type_or_path)
+ speech_mm_projector = SpeechMultimodalProjector(speech_mm_projector_cfg, config).to(eval(config.model_dtype))
+ return speech_mm_projector
+
+def build_sound_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
+ if model_type_or_path is None:
+ return None
+
+ ## load from pretrained model
+ if config.resume_path:
+ print(config.resume_path)
+ assert os.path.exists(model_type_or_path), f"Resume sound mm projector path {model_type_or_path} does not exist!"
+ return SoundMultimodalProjector.from_pretrained(model_type_or_path, config, torch_dtype=eval(config.model_dtype))
+ # build from scratch
+ else:
+ print("WARNING: Building sound multimodal projector from scratch!")
+ sound_mm_projector_cfg = SoundMultimodalProjectorConfig(model_type_or_path)
+ sound_mm_projector = SoundMultimodalProjector(sound_mm_projector_cfg, config).to(eval(config.model_dtype))
+ return sound_mm_projector
diff --git a/llava/model/multimodal_projector/sound_base_projector.py b/llava/model/multimodal_projector/sound_base_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..16d96b3f7ab5a3f2ad4b4a18a4982dc751588155
--- /dev/null
+++ b/llava/model/multimodal_projector/sound_base_projector.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+
+
+class SoundMultimodalProjectorConfig(PretrainedConfig):
+ model_type = "sound_mm_projector"
+
+ def __init__(self, sound_mm_projector_type: str = None, **kwargs):
+ super().__init__()
+ self.sound_mm_projector_type = sound_mm_projector_type
+
+
+class SoundMultimodalProjector(PreTrainedModel):
+ config_class = SoundMultimodalProjectorConfig
+
+ def __init__(self, sound_mm_projector_cfg: SoundMultimodalProjectorConfig, config: PretrainedConfig):
+ super().__init__(sound_mm_projector_cfg)
+ # sound_mm_projector_type = sound_mm_projector_cfg.sound_mm_projector_type
+ sound_mm_projector_type = "mlp"
+
+ if sound_mm_projector_type == "mlp":
+ self.layers = nn.Sequential(
+ nn.Linear(config.sound_hidden_size, config.hidden_size),
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+ else:
+ raise ValueError(f"Unknown projector type: {sound_mm_projector_type}")
+
+ def forward(self, x, *args, **kwargs):
+ return self.layers(x)
+
+
+AutoConfig.register("sound_mm_projector", SoundMultimodalProjectorConfig)
+AutoModel.register(SoundMultimodalProjectorConfig, SoundMultimodalProjector)
diff --git a/llava/model/multimodal_projector/speech_base_projector.py b/llava/model/multimodal_projector/speech_base_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9856332e8e64cc95218d278deb8e7af8888a60d
--- /dev/null
+++ b/llava/model/multimodal_projector/speech_base_projector.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+
+
+class SpeechMultimodalProjectorConfig(PretrainedConfig):
+ model_type = "speech_mm_projector"
+
+ def __init__(self, speech_mm_projector_type: str = None, **kwargs):
+ super().__init__()
+ self.speech_mm_projector_type = speech_mm_projector_type
+
+
+class SpeechMultimodalProjector(PreTrainedModel):
+ config_class = SpeechMultimodalProjectorConfig
+
+ def __init__(self, speech_mm_projector_cfg: SpeechMultimodalProjectorConfig, config: PretrainedConfig):
+ super().__init__(speech_mm_projector_cfg)
+ # speech_mm_projector_type = speech_mm_projector_cfg.speech_mm_projector_type
+ speech_mm_projector_type = "mlp"
+
+ if speech_mm_projector_type == "mlp":
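+            # A stride-2 Conv1d halves the number of speech frames before the MLP projects
+            # speech_hidden_size features to the LLM hidden size.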
+ self.conv = nn.Conv1d(config.speech_hidden_size, config.speech_hidden_size, kernel_size=2, stride=2)
+ self.layers = nn.Sequential(
+ nn.Linear(config.speech_hidden_size, config.hidden_size),
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+ else:
+ raise ValueError(f"Unknown projector type: {speech_mm_projector_type}")
+
+ def forward(self, x, *args, **kwargs):
+ x = self.conv(x.transpose(1,2)).transpose(1,2)
+ return self.layers(x)
+
+
+AutoConfig.register("speech_mm_projector", SpeechMultimodalProjectorConfig)
+AutoModel.register(SpeechMultimodalProjectorConfig, SpeechMultimodalProjector)
diff --git a/llava/model/qfunction.py b/llava/model/qfunction.py
new file mode 100644
index 0000000000000000000000000000000000000000..196ae8f7f614004a48ab90e8fcb28dd44b628afd
--- /dev/null
+++ b/llava/model/qfunction.py
@@ -0,0 +1,280 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import re
+
+import torch
+
+try:
+ from .FloatPointQuantizeTorch import *
+ from .FloatPointQuantizeTriton import *
+except ImportError:
+ from FloatPointQuantizeTorch import *
+ from FloatPointQuantizeTriton import *
+
+
+def block_cut(input, row_block, column_block, pad_block=False):
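+    # Split a 2D (or flattened 3D) tensor into (row_num * column_num, row_block, column_block) tiles;
+    # a block size of -1 means "use the whole dimension", and pad_block zero-pads ragged edges.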
+ # print(input.shape)
+ original_shape = input.shape
+ # input tensor shape is M * N
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+ input = torch.nn.functional.pad(
+ input, (0, col_pad, 0, row_pad), "constant", 0
+ ) # refer to torch's doc to see why
+ M, N = input.shape[0], input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}"
+ assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}"
+ # print(input.shape)
+ input = (
+ input.reshape(row_num, row_block, column_num, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * column_num, row_block, column_block)
+ )
+ # print(input.shape)
+ return input
+
+
+def block_reshape(input, origin_input, row_block, column_block, pad_block=False):
+ if len(origin_input.shape) > 2:
+ flatten_input = origin_input.reshape(-1, origin_input.shape[2])
+ elif len(origin_input.shape) == 2:
+ flatten_input = origin_input
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut")
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+ pad_origin_input = torch.nn.functional.pad(origin_input, (0, col_pad, 0, row_pad), "constant", 0)
+ M, N = pad_origin_input.shape[0], pad_origin_input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ input = (
+ input.reshape(row_num, column_num, row_block, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * row_block, column_num * column_block)
+ )
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+ input = input[:M, :N]
+
+ if len(origin_input.shape) > 2:
+ input = input.reshape(origin_input.shape)
+ elif len(origin_input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block reshape")
+
+ return input
+
+
+def block_verify_int8(input, row_block, column_block, layer_type, necessary=True):
+ Binput = block_cut(input, row_block, column_block)
+ Binput = Binput.to(torch.float32)
+
+ for n in range(Binput.shape[0]):
+ unique_values = len(torch.unique(Binput[n, :, :]))
+ if unique_values > 256:
+ if necessary:
+ raise ValueError(f"{layer_type} contains more than 256 unique values.")
+ else:
+ return False
+ return True
+
+
+def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name):
+ Quant_fn = SymmQuantizer
+ return Quant_fn.apply(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name)
+
+
+def extract_bit(string):
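+    # Parse a quantization format string, e.g. "INT8" -> ("integer", 8, None),
+    # "E4M3" -> ("floatExMy", 4, 3), "E1M6" -> ("integer", 7, None), "DE4" -> ("Dynamic", 4, None).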
+ match = re.match(r"INT(\d+)", string)
+ if match:
+ return "integer", int(match.group(1)), None
+ match = re.match(r"E(\d+)M(\d+)", string)
+ if match:
+ Ebit, Mbit = int(match.group(1)), int(match.group(2))
+ if Ebit == 1:
+ return "integer", Mbit + 1, None
+ if Mbit == 0:
+ return "floatExM0", int(match.group(1)), 0
+ return "floatExMy", int(match.group(1)), int(match.group(2))
+ match = re.match(r"DE(\d+)", string)
+ if match:
+ return "Dynamic", int(match.group(1)), None
+ match = re.match(r"ZeroD(\d+)", string)
+ if match:
+ return "ZeroDynamic", int(match.group(1)), None
+ raise ValueError(f"{string} data format is not supported")
+
+
+class SymmQuantizer(torch.autograd.function.InplaceFunction):
+ @staticmethod
+ def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None):
+ with torch.no_grad():
+ absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon
+
+ if bits == "100" or not apply_quantize:
+ return input, input, torch.ones_like(absmax_per_block)
+ elif bits == "FP32":
+ return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block)
+ elif bits == "FP16":
+ return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block)
+ elif bits == "BF16":
+ return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block)
+ else:
+ QuantType, bit1, bit2 = extract_bit(bits)
+ if not symm:
+                    bit1 = bit1 + 1  # pretend to be asymmetric
+
+ if QuantType == "integer":
+ Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1
+ elif QuantType == "floatExMy":
+ Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * (
+ 2 ** (2 ** (bit1 - 1))
+ )
+ if bit1 == 4 and bit2 == 3:
+ Qn, Qp = -448, 448
+ if bit1 == 5 and bit2 == 2:
+ Qn, Qp = -57344, 57344
+ elif QuantType == "floatExM0":
+ Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1))
+ elif QuantType == "Dynamic":
+ Qn, Qp = -1, 1
+ elif QuantType == "ZeroDynamic":
+ Qn, Qp = -1, 1
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
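+                # Symmetric per-block scaling: map [-absmax_per_block, absmax_per_block] onto [Qn, Qp].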
+ scale_per_block = (2 * absmax_per_block) / (Qp - Qn)
+ scale_per_block = scale_per_block.to(input)
+
+ Qinput = input / scale_per_block
+
+ if QuantType == "integer":
+ if stochastic:
+ noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5)
+ Qinput.add_(noise)
+ Qinput.clamp_(Qn, Qp).round_()
+ elif QuantType == "floatExMy":
+ # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic)
+ Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic)
+ elif QuantType == "floatExM0":
+ Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic)
+ elif QuantType == "Dynamic":
+ Qinput = Dynamic_quantize_torch(Qinput, bit1, stochastic)
+ elif QuantType == "ZeroDynamic":
+ Qinput = ZeroDynamic_quantize_torch(Qinput, bit1, stochastic)
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
+
+ RQinput = Qinput * scale_per_block
+ # print(f'Layer Name: {layer_name}: {input, Qinput, scale_per_block, absmax_per_block}', file=open('debug.txt', 'a'))
+
+ if input.dtype != Qinput.dtype:
+ print(
+ f"Input type is {input.dtype}, Qinput type is {Qinput.dtype}, scale_per_block type is {scale_per_block.dtype}",
+ file=open("debug.txt", "a"),
+ )
+ import IPython
+
+ IPython.embed()
+ return RQinput, Qinput, scale_per_block
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None, None, None, None, None
+
+
+if __name__ == "__main__":
+ import time
+
+ torch.manual_seed(0)
+ X = (torch.rand(2048, 4096).cuda() - 0.5) * 1000
+
+ X = X.to(torch.bfloat16)
+
+ B_X = block_cut(X, -1, -1)
+ RQ_X, Q_X, S_X = block_quant(
+ B_X, symm=True, bits="E2M0", stochastic=True, epsilon=1e-14, apply_quantize=True, layer_name=""
+ )
+ RQ_X = block_reshape(RQ_X, X, -1, -1)
+ Q_X = block_reshape(Q_X, X, -1, -1)
+
+ # start = time.time()
+ # for _ in range(10000):
+ # B_X = block_cut(X, 2, 2)
+ # RQ_X, Q_X, S_X = block_quant(B_X, symm=True, bits="E5M2", stochastic=True, epsilon=1e-14, apply_quantize=True, layer_name='')
+ # RQ_X = block_reshape(RQ_X, X, 2, 2)
+ # Q_X = block_reshape(Q_X, X, 2, 2)
+ # torch.cuda.synchronize()
+
+ # end = time.time()
+ # print(f"Time cost: {end - start}")
+
+ print(X.dtype)
+ import IPython
+
+ IPython.embed()
+
+ # X = torch.tensor([
+ # [1.1, 1, 1, 2, 2, 2, 3, 3, 3.01],
+ # [1.2, 1, 1, 2, 2, 2, 3, 3, 3.02],
+ # [4.1, 4, 4, 5, 5, 5, 6, 6, 6.01],
+ # [4.2, 4, 4, 5, 5, 5, 6, 6, 6.02],
+ # [7.1, 7, 7, 8, 8, 8, 9, 9, 9.01],
+ # [7.2, 7, 7, 8, 8, 8, 9, 9, 9.02],
+ # ])
+
+ # B_X = block_cut(X, 2, 3)
+ # print(B_X)
+ # print(B_X.shape)
+ # Q_X = block_quant(B_X, symm=True, bits=2, stochastic=False)
+ # print(Q_X)
+ # print(Q_X.shape)
diff --git a/llava/model/qlinear_te.py b/llava/model/qlinear_te.py
new file mode 100644
index 0000000000000000000000000000000000000000..b290b053fb51a7ece2ead906139042898dd697b2
--- /dev/null
+++ b/llava/model/qlinear_te.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+import time
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function, InplaceFunction
+from torch.cuda import amp
+
+from .language_model.configuration_quantize import QuantizationConfig
+from .qfunction import block_cut, block_quant, block_reshape
+from .qutils import quant_get_local_rank
+from .realquantize.division_transpose import fp8_division_transpose
+from .realquantize.linear import fp8_linear_backward, fp8_linear_forward
+from .realquantize.quantize_and_transpose import fp8_quantize_and_transpose
+
+
+class QLinearTE(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True, device=None, args=None, layer_idx=0):
+ super().__init__(in_features, out_features, bias, device)
+ try: # TODO: remove this try except (llama & qwen2)
+ self.args = QuantizationConfig(**deepcopy(args))
+        except Exception:
+ self.args = deepcopy(args)
+
+ self.apply_quantize = min(self.weight.shape[0], self.weight.shape[1]) >= 3584
+
+ if quant_get_local_rank() == 0:
+ if self.apply_quantize:
+                print(f"[qlinear debug] Applying QLinear to layer {layer_idx}")
+            else:
+                print(f"[qlinear debug] Skipping QLinear for layer {layer_idx} since its weight is too small: {self.weight.shape}")
+ self.layer_idx = layer_idx
+ self.layer_name = None
+
+ def forward(self, Input):
+ # if torch.randn(1) < 0.01:
+ # print(Input.shape, self.weight.shape)
+ if self.training and self.apply_quantize:
+ # if False:
+ output = QuantLinearTE.apply(Input, self.weight, self.bias, self.args, self.layer_name)
+ else:
+ output = F.linear(Input, self.weight, self.bias)
+
+ return output
+
+
+# if int(os.environ.get("LOCAL_RANK")) == 0:
+# import IPython
+# IPython.embed()
+# else:
+# import time
+# time.sleep(1000)
+
+# class QuantLinearTE(Function):
+# @staticmethod
+# def forward(ctx, input, weight, bias, args, layer_type):
+# ctx.saved = input, weight, bias, args, layer_type
+# return F.linear(input, weight, bias)
+
+# @staticmethod
+# def backward(ctx, grad_output):
+# input, weight, bias, args, layer_type = ctx.saved
+
+# C_in = input.shape[-1]
+# C_out = grad_output.shape[-1]
+
+# grad_output_flatten = grad_output.reshape(-1, C_out)
+# input_flatten = input.reshape(-1, C_in)
+
+# if grad_output_flatten.dtype == input_flatten.dtype:
+# grad_weight = grad_output_flatten.t().mm(input_flatten)
+# else:
+# grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+
+# if grad_output_flatten.dtype == weight.dtype:
+# grad_input = grad_output_flatten.mm(weight)
+# else:
+# grad_input = grad_output_flatten.float().mm(weight)
+
+# if bias is not None:
+# grad_bias = grad_output_flatten.sum(0)
+# else:
+# grad_bias = None
+
+# grad_input_transform = grad_input.reshape(input.size())
+
+# return grad_input_transform, grad_weight, grad_bias, None, None
+
+
+class QuantLinearTE(Function):
+ @staticmethod
+ @amp.custom_fwd(cast_inputs=torch.bfloat16)
+ def forward(ctx, input, weight, bias, args, layer_name):
+
+ time_bench = os.getenv("TIME_BENCH")
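+        # When TIME_BENCH is set, time the FP8 path with CUDA events and also run a BF16 F.linear reference for comparison.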
+
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Qinput, Iscale, Qinput_t = fp8_division_transpose(input, 16, args.fabit)
+ Qinput, Iscale, Qinput_t = fp8_quantize_and_transpose(input, 16, args.fabit, transpose_output_2d=True)
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ # Qweight, Wscale, Qweight_t = fp8_division_transpose(weight, 16, args.fwbit)
+ Qweight, Wscale, Qweight_t = fp8_quantize_and_transpose(weight, 16, args.fwbit, transpose_output_2d=True)
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ ctx.saved = Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name
+ fc_output = fp8_linear_forward(Qinput, Iscale, Qweight, Wscale, False, 0, bias)
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ output = F.linear(input, weight, bias)
+
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if quant_get_local_rank() == 0:
+ print(
+ f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {input.shape} | Weight shape: {weight.shape}"
+ )
+
+ return fc_output
+
+ @staticmethod
+ @amp.custom_bwd
+ def backward(ctx, grad_output):
+ Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
+
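+ # Quantize grad_output to FP8, then let fp8_linear_backward compute both the input
+ # gradient (dY @ W) and the weight gradient (dY^T @ X) from the FP8 tensors cached in
+ # forward. The BF16 matmuls further below are only executed for the timing comparison.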
+ time_bench = os.getenv("TIME_BENCH")
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Qgrad_output, Gscale, Qgrad_output_t = fp8_division_transpose(grad_output, 16, args.bobit, stochastic=False)
+ Qgrad_output, Gscale, Qgrad_output_t = fp8_quantize_and_transpose(
+ grad_output, 16, args.bobit, stochastic=False, transpose_output_2d=True
+ )
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ grad_input, grad_weight = fp8_linear_backward(
+ Qinput_t,
+ Iscale,
+ Qgrad_output,
+ Gscale,
+ Qgrad_output_t,
+ Qweight_t,
+ Wscale,
+ 16,
+ bias,
+ stochastic=False,
+ dgrad_quantize=False,
+ )
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ if bias is not None:
+ grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+ else:
+ grad_bias = None
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+
+ # ========== BF16 ==========
+ C_in = Qinput_t.shape[0]
+ C_out = grad_output.shape[-1]
+ grad_output_flatten = grad_output.reshape(-1, C_out)
+ input_flatten = Qinput_t.t().reshape(-1, C_in).to(torch.bfloat16)
+ weight = Qweight_t.t().to(torch.bfloat16)
+
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ if grad_output_flatten.dtype == input_flatten.dtype:
+ _grad_weight = grad_output_flatten.t().mm(input_flatten)
+ else:
+ _grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+
+ if grad_output_flatten.dtype == weight.dtype:
+ _grad_input = grad_output_flatten.mm(weight)
+ else:
+ _grad_input = grad_output_flatten.float().mm(weight)
+
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if quant_get_local_rank() == 0:
+ print(
+ f"[Backward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {Qinput_t.shape} | Weight shape: {weight.shape}"
+ )
+
+ return grad_input, grad_weight, grad_bias, None, None
diff --git a/llava/model/quantization/FloatPointQuantizeTorch.py b/llava/model/quantization/FloatPointQuantizeTorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..22e08223ba4cd95ccc6abe7a672201cd545505c5
--- /dev/null
+++ b/llava/model/quantization/FloatPointQuantizeTorch.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import math
+
+import torch
+
+
+def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
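+ # Reference PyTorch ExMy float quantization: split |x| into a clamped power-of-two
+ # exponent and a mantissa, round the mantissa to m_bit fractional bits (optionally with
+ # stochastic rounding), then reassemble sign * 2^exponent * mantissa.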
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Mhigh = -(2 ** (e_bit - 1)), 2**m_bit - 1
+ expo = torch.floor(torch.log2(x_abs))
+ expo = torch.clamp(expo, min=Elow)
+ mant = x_abs / torch.exp2(expo)
+
+ mant_int = torch.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * (Mhigh + 1)
+ if stochastic:
+ noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5)
+ mant_frac.add_(noise)
+ mant_frac = torch.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / (Mhigh + 1)
+ y = sign * (2**expo) * mant_q
+ y = y.to(x)
+ return y
+
+
+def floatExM0_quantize_torch(x, e_bit, stochastic):
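+ # Exponent-only (ExM0) quantization: each value is rounded to a signed power of two;
+ # log_bias nudges the rounding point in log space, and noise implements stochastic rounding.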
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh = -(2 ** (e_bit - 1)), 2 ** (e_bit - 1)
+ expo = torch.log2(x_abs)
+ if stochastic:
+ noise = expo.new(expo.shape).uniform_(-0.5, 0.5)
+ expo.add_(noise)  # in-place; a bare .add() would discard the noise
+ log_bias = math.log2(4 / 3) - 1 / 2
+ expo.add_(torch.ones_like(expo) * log_bias)
+ expo = torch.clamp(expo, min=Elow, max=Ehigh)
+ expo = torch.round(expo)
+
+ y = sign * (2**expo)
+ y = y.to(x)
+ return y
+
+
+def Dynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ zero_mask = y.abs() > 1.01 * 10 ** (1 - bit)
+
+ y = y * zero_mask
+ y = y.to(x)
+ return y
+
+
+def ZeroDynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ y = y.to(x)
+ return y
diff --git a/llava/model/quantization/FloatPointQuantizeTriton.py b/llava/model/quantization/FloatPointQuantizeTriton.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dfb5839b79e530b6b0b585100e47c1627de9f0d
--- /dev/null
+++ b/llava/model/quantization/FloatPointQuantizeTriton.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import math
+import struct
+
+import numpy as np
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+
+def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic):
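+ # Triton-backed version of the ExMy quantizer: flattens the tensor and launches one of the
+ # kernels below over a 1D grid; stochastic rounding uses pre-sampled uniform noise.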
+ n_elements = x.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ y = torch.zeros_like(x)
+
+ if x.dtype in [torch.bfloat16, torch.float32]:
+ if stochastic:
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
+ else:
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
+ else:
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
+
+ return y
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_stages=1,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_quantize_kernel(
+ x_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ # mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_stages=1,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_stochastic_quantize_kernel(
+ x_ptr,
+ noise_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+ noise = tl.load(noise_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
diff --git a/llava/model/quantization/QAct.py b/llava/model/quantization/QAct.py
new file mode 100644
index 0000000000000000000000000000000000000000..56ed63b1a96c567e7b8975ee9fa3036e465f1e81
--- /dev/null
+++ b/llava/model/quantization/QAct.py
@@ -0,0 +1,517 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import Function, InplaceFunction
+
+try:
+ from .Qconfig import qconfig
+ from .QFunction import *
+ from .utils import *
+except ImportError:
+ from Qconfig import qconfig
+ from utils import *
+ from QFunction import *
+
+import os
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+
+
+class QAct_FPout(nn.Identity):
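+ """Activation boundary that takes a block-quantized pair (Qinput, Iscale) and returns the
+ dequantized full-precision tensor; its backward re-quantizes the incoming gradient.
+ Counterpart of QAct_FPin."""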
+ def __init__(self, args, normalize_before=False, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.normalize_before = normalize_before
+ self.layer_type = layer_type
+ assert layer_type != "", "layer_type is not defined"
+ assert layer_type in qconfig.qact_config, f"{layer_type} not in qact_config"
+
+ self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qact_config[layer_type])
+ self.apply_quantize_f, self.apply_quantize_b = self.apply_quantize, self.apply_quantize
+
+ self.refine_rowcol_blocksize()
+
+ self.fbit = self.args.fabit if self.args.fabit else self.Ubit
+ self.bbit = self.args.babit if self.args.babit else self.Ubit
+
+ quantize_flag = format_string_with_condition(
+ layer_type,
+ {"apply-f": self.apply_quantize_f, "apply-b": self.apply_quantize_b},
+ self.args.symm,
+ self.fbit,
+ self.bbit,
+ {
+ "row-f": self.args.row_blocksize_f,
+ "col-f": self.args.col_blocksize_f,
+ "row-b": self.args.row_blocksize_b,
+ "col-b": self.args.col_blocksize_b,
+ },
+ )
+ if quant_get_local_rank() == 0:
+ print(quantize_flag)
+
+ def refine_rowcol_blocksize(self):
+ self.args.row_blocksize_f, self.args.col_blocksize_f = self.args.row_blocksize, self.args.col_blocksize
+ self.args.row_blocksize_b, self.args.col_blocksize_b = self.args.row_blocksize, self.args.col_blocksize
+ if self.args.refine_residual_fp:
+ if self.layer_type in ["add_attn_in_re", "add_mlp_in_re"]:
+ self.apply_quantize_f, self.apply_quantize_b = False, False
+
+ if self.args.refine_ln_blocksize:
+ if self.layer_type in ["ln_attn_in"]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ assert not (
+ self.args.refine_ln_blocksize_but_only_forward and self.args.refine_ln_blocksize_but_only_backward
+ ) # This will not happen at the same time
+ if self.args.refine_ln_blocksize_but_only_forward:
+ self.apply_quantize_f, self.apply_quantize_b = True, False
+ if self.args.refine_ln_blocksize_but_only_backward:
+ self.apply_quantize_f, self.apply_quantize_b = False, True
+
+ if self.layer_type in [
+ "ln_mlp_in",
+ ]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ assert not (
+ self.args.refine_ln_blocksize_but_only_forward and self.args.refine_ln_blocksize_but_only_backward
+ ) # This will not happen at the same time
+ if self.args.refine_ln_blocksize_but_only_forward:
+ self.apply_quantize_f, self.apply_quantize_b = True, False
+ if self.args.refine_ln_blocksize_but_only_backward:
+ self.apply_quantize_f, self.apply_quantize_b = False, True
+
+ if self.args.refine_attn_blocksize:
+ if self.layer_type in ["ln_attn_in"]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ if self.layer_type in ["attn_qkv_sum"]:
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in ["add_attn_in_fx"]:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ if self.args.refine_mlp_blocksize:
+ if self.layer_type in [
+ "ln_mlp_in",
+ ]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ if self.layer_type in ["mlp_act_sum"]:
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in ["mlp_act_in"]:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in [
+ "mul_act_in1",
+ ]:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in [
+ "mul_act_in2",
+ ]:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in ["add_mlp_in_fx"]:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ def forward(self, Qinput, Iscale):
+ # input shape is (Batch Size, Sequence Length, Hidden Size)
+ if self.training:
+ return QuantAct_FPout.apply(
+ Qinput, Iscale, self.args, self.layer_name, self.apply_quantize_f, self.apply_quantize_b
+ )
+ else:
+ return Qinput
+
+
+class QuantAct_FPout(Function):
+ @staticmethod
+ def forward(ctx, Qinput, Iscale, args, layer_name, apply_quantize_f=True, apply_quantize_b=True):
+ ctx.saved = args, layer_name, apply_quantize_f, apply_quantize_b
+
+ # shrink the padded Iscale down to the number of scale blocks actually used (padding keeps forward/backward tensor sizes consistent for autograd)
+ ideal_scale_num = Qinput.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(Qinput, args.row_blocksize_f, args.col_blocksize_f)
+ # actual_scale_num = Qinput.numel() / (args.row_blocksize_f * args.col_blocksize_f)
+ assert Iscale.shape[0] == ideal_scale_num
+ Iscale = Iscale[: int(actual_scale_num), :, :]
+
+ Binput = block_cut(Qinput, args.row_blocksize_f, args.col_blocksize_f)
+ input = Binput * Iscale
+ input = block_reshape(input, Qinput, args.row_blocksize_f, args.col_blocksize_f)
+
+ if args.draw_distribution_forward:
+ save_tensor(input, None, None, fb="forward", aw="Activation", layer_name=layer_name)
+
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ args, layer_name, apply_quantize_f, apply_quantize_b = ctx.saved
+
+ Bgrad_output = block_cut(grad_output, args.row_blocksize_b, args.col_blocksize_b)
+ RQgrad_output, Qgrad_output, Gscale = block_quant(
+ Bgrad_output,
+ args.symm,
+ args.babit,
+ stochastic=True,
+ epsilon=args.epsilon,
+ apply_quantize=apply_quantize_b,
+ layer_name=layer_name,
+ )
+ Qgrad_output = block_reshape(Qgrad_output, grad_output, args.row_blocksize_b, args.col_blocksize_b)
+
+ if args.draw_distribution_backward:
+ save_tensor(grad_output, RQgrad_output, Qgrad_output, fb="backward", aw="Activation", layer_name=layer_name)
+
+ # pad Gscale so the gradient w.r.t. Iscale has the same shape as the Iscale passed to forward
+ ideal_scale_num = grad_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(grad_output, args.row_blocksize_b, args.col_blocksize_b)
+ # actual_scale_num = grad_output.numel() / (args.row_blocksize_b * args.col_blocksize_b)
+ assert Gscale.shape[0] == actual_scale_num
+ Gscale = torch.nn.functional.pad(Gscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))
+
+ return Qgrad_output, Gscale, None, None, None, None
+
+
+class QAct_FPin(nn.Identity):
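+ """Activation boundary that block-quantizes a full-precision tensor into (Qinput, Iscale)
+ in forward; its backward dequantizes the incoming (Qgrad_output, Gscale) pair.
+ Counterpart of QAct_FPout."""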
+ def __init__(self, args, normalize_before=False, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.normalize_before = normalize_before
+ self.layer_type = layer_type
+ assert layer_type != "", "layer_type is not defined"
+ assert layer_type in qconfig.qact_config, f"{layer_type} not in qact_config"
+
+ self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qact_config[layer_type])
+ self.apply_quantize_f, self.apply_quantize_b = self.apply_quantize, self.apply_quantize
+
+ self.refine_rowcol_blocksize()
+
+ self.fbit = self.args.fabit if self.args.fabit else self.Ubit
+ self.bbit = self.args.babit if self.args.babit else self.Ubit
+
+ quantize_flag = format_string_with_condition(
+ layer_type,
+ {"apply-f": self.apply_quantize_f, "apply-b": self.apply_quantize_b},
+ self.args.symm,
+ self.fbit,
+ self.bbit,
+ {
+ "row-f": self.args.row_blocksize_f,
+ "col-f": self.args.col_blocksize_f,
+ "row-b": self.args.row_blocksize_b,
+ "col-b": self.args.col_blocksize_b,
+ },
+ )
+ if quant_get_local_rank() == 0:
+ print(quantize_flag)
+
+ def refine_rowcol_blocksize(self):
+ self.args.row_blocksize_f, self.args.col_blocksize_f = self.args.row_blocksize, self.args.col_blocksize
+ self.args.row_blocksize_b, self.args.col_blocksize_b = self.args.row_blocksize, self.args.col_blocksize
+
+ if self.args.refine_residual_fp:
+ if self.layer_type in ["re_attn_out_re", "re_mlp_out_re"]:
+ self.apply_quantize_f, self.apply_quantize_b = False, False
+
+ if self.args.refine_ln_blocksize:
+ if self.layer_type in ["re_attn_out_fx"]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ assert not (
+ self.args.refine_ln_blocksize_but_only_forward and self.args.refine_ln_blocksize_but_only_backward
+ ) # This will not happen at the same time
+ if self.args.refine_ln_blocksize_but_only_forward:
+ self.apply_quantize_f, self.apply_quantize_b = True, False
+ if self.args.refine_ln_blocksize_but_only_backward:
+ self.apply_quantize_f, self.apply_quantize_b = False, True
+
+ if self.layer_type in ["re_mlp_out_fx"]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ assert not (
+ self.args.refine_ln_blocksize_but_only_forward and self.args.refine_ln_blocksize_but_only_backward
+ ) # This will not happen at the same time
+ if self.args.refine_ln_blocksize_but_only_forward:
+ self.apply_quantize_f, self.apply_quantize_b = True, False
+ if self.args.refine_ln_blocksize_but_only_backward:
+ self.apply_quantize_f, self.apply_quantize_b = False, True
+
+ if self.args.refine_attn_blocksize:
+ if self.layer_type in ["re_attn_out_fx"]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ if self.layer_type in ["ln_attn_out"]:
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in ["attn_q_in", "attn_k_in", "attn_v_in"]:
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ if self.args.refine_mlp_blocksize:
+ if self.layer_type in ["re_mlp_out_fx"]:
+ if self.args.refine_ln_pertoken:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ 1,
+ self.args.refine_row_blocksize * self.args.refine_col_blocksize,
+ )
+ else:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ if self.layer_type in ["ln_mlp_out"]:
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in ["mlp_act_gate", "mlp_act_up", "mul_act_out"]:
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in ["mlp_act_out"]:
+ self.args.row_blocksize_f, self.args.col_blocksize_f = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_b, self.args.col_blocksize_b = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ def forward(self, input):
+ # input shape is (Batch Size, Sequence Length, Hidden Size)
+ if self.training:
+ return QuantAct_FPin.apply(input, self.args, self.layer_name, self.apply_quantize_f, self.apply_quantize_b)
+ else:
+ return input, None
+
+
+class QuantAct_FPin(Function):
+ @staticmethod
+ def forward(ctx, input, args, layer_name, apply_quantize_f=True, apply_quantize_b=True):
+ ctx.saved = args, layer_name, apply_quantize_f, apply_quantize_b
+
+ Binput = block_cut(input, args.row_blocksize_f, args.col_blocksize_f)
+ RQinput, Qinput, Iscale = block_quant(
+ Binput,
+ args.symm,
+ args.fabit,
+ stochastic=False,
+ epsilon=args.epsilon,
+ apply_quantize=apply_quantize_f,
+ layer_name=layer_name,
+ )
+ Qinput = block_reshape(Qinput, input, args.row_blocksize_f, args.col_blocksize_f)
+ RQinput = block_reshape(RQinput, input, args.row_blocksize_f, args.col_blocksize_f)
+
+ if args.draw_distribution_forward:
+ save_tensor(input, RQinput, Qinput, fb="forward", aw="Activation", layer_name=layer_name)
+
+ # pad Iscale up to the ideal number of scale blocks so its backward gradient can share this shape
+ ideal_scale_num = input.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(input, args.row_blocksize_f, args.col_blocksize_f)
+ # actual_scale_num = input.numel() / (args.row_blocksize_f * args.col_blocksize_f)
+ assert Iscale.shape[0] == actual_scale_num
+ Iscale = torch.nn.functional.pad(Iscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))
+
+ return Qinput, Iscale
+
+ @staticmethod
+ def backward(ctx, Qgrad_output, Gscale):
+ args, layer_name, apply_quantize_f, apply_quantize_b = ctx.saved
+
+ # shrink the padded Gscale down to the number of scale blocks actually used in backward
+ ideal_scale_num = Qgrad_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(Qgrad_output, args.row_blocksize_b, args.col_blocksize_b)
+ # actual_scale_num = Qgrad_output.numel() / (args.row_blocksize_b * args.col_blocksize_b)
+ assert Gscale.shape[0] == ideal_scale_num
+ Gscale = Gscale[: int(actual_scale_num), :, :]
+
+ Bgrad_output = block_cut(Qgrad_output, args.row_blocksize_b, args.col_blocksize_b)
+ grad_output = Bgrad_output * Gscale
+ grad_output = block_reshape(grad_output, Qgrad_output, args.row_blocksize_b, args.col_blocksize_b)
+
+ if args.draw_distribution_backward:
+ save_tensor(grad_output, None, None, fb="backward", aw="Activation", layer_name=layer_name)
+
+ return grad_output, None, None, None, None
+
+
+if __name__ == "__main__":
+ Sum = torch.load("tensor/QAct_nan_epoch16.pt")
+ Qinput, Binput, input, args, layer_type, name = (
+ Sum["Qinput"],
+ Sum["Binput"],
+ Sum["input"],
+ Sum["args"],
+ Sum["layer_type"],
+ Sum["name"],
+ )
+ if_nan, if_inf = check_nan_inf(input, True, False)
+ print(if_nan)
+
+ Q = block_quant(Binput, True, 8, stochastic=False, epsilon=1e-8)
diff --git a/llava/model/quantization/QAdd.py b/llava/model/quantization/QAdd.py
new file mode 100644
index 0000000000000000000000000000000000000000..c486d6370bfdddb86bb3c4cd9f6416c0d13c4ee0
--- /dev/null
+++ b/llava/model/quantization/QAdd.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import Function, InplaceFunction
+
+try:
+ from .QAct import QAct_FPin, QAct_FPout
+ from .Qconfig import qconfig
+ from .QFunction import *
+ from .utils import *
+
+except ImportError:
+ from Qconfig import qconfig
+ from utils import *
+ from QFunction import *
+ from QAct import QAct_FPin, QAct_FPout
+
+import os
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+
+
+class QAdd(nn.Module):
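+ """Quantization-aware residual add: dequantizes the residual and main-branch inputs via
+ QAct_FPout and sums them in full precision."""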
+ def __init__(self, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.layer_type = layer_type
+ assert layer_type != "", "layer_type is not defined"
+ assert layer_type in qconfig.qadd_config, f"{layer_type} not in qadd_config"
+
+ self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qadd_config[layer_type])
+
+ self.fbit = self.args.fabit if self.args.fabit else self.Ubit
+ self.bbit = self.args.babit if self.args.babit else self.Ubit
+
+ quantize_flag = format_string_with_condition(
+ layer_type,
+ {"apply": self.apply_quantize},
+ self.args.symm,
+ self.fbit,
+ self.bbit,
+ {"row": self.args.row_blocksize, "col": self.args.col_blocksize},
+ )
+ print(quantize_flag)
+
+ self.Add_in_re = QAct_FPout(args, layer_type=layer_type + "_in_re")
+ self.Add_in_fx = QAct_FPout(args, layer_type=layer_type + "_in_fx")
+
+ def forward(self, Qinput_re, Qinput_fx, Iscale_re, Iscale_fx):
+ # input shape is (Batch Size, Sequence Length, Hidden Size)
+ input1 = self.Add_in_re(Qinput_re, Iscale_re)
+ input2 = self.Add_in_fx(Qinput_fx, Iscale_fx)
+ output_fp = input1 + input2
+ return output_fp
+
+
+if __name__ == "__main__":
+ Sum = torch.load("tensor/QAct_nan_epoch16.pt")
diff --git a/llava/model/quantization/QFunction.py b/llava/model/quantization/QFunction.py
new file mode 100644
index 0000000000000000000000000000000000000000..196ae8f7f614004a48ab90e8fcb28dd44b628afd
--- /dev/null
+++ b/llava/model/quantization/QFunction.py
@@ -0,0 +1,280 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import re
+
+import torch
+
+try:
+ from .FloatPointQuantizeTorch import *
+ from .FloatPointQuantizeTriton import *
+except ImportError:
+ from FloatPointQuantizeTorch import *
+ from FloatPointQuantizeTriton import *
+
+
+def block_cut(input, row_block, column_block, pad_block=False):
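+ # Flatten any leading dims, optionally zero-pad, and tile the resulting M x N matrix into a
+ # (num_blocks, row_block, column_block) tensor so each block can be scaled independently.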
+ # print(input.shape)
+ original_shape = input.shape
+ # input tensor shape is M * N
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+ input = torch.nn.functional.pad(
+ input, (0, col_pad, 0, row_pad), "constant", 0
+ ) # F.pad pads the last dim first: (left, right) for columns, then (top, bottom) for rows
+ M, N = input.shape[0], input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}"
+ assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}"
+ # print(input.shape)
+ input = (
+ input.reshape(row_num, row_block, column_num, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * column_num, row_block, column_block)
+ )
+ # print(input.shape)
+ return input
+
+
+def block_reshape(input, origin_input, row_block, column_block, pad_block=False):
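+ # Inverse of block_cut: stitch the (num_blocks, row_block, column_block) tiles back into the
+ # original layout, crop any padding, and restore the original leading dims.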
+ if len(origin_input.shape) > 2:
+ flatten_input = origin_input.reshape(-1, origin_input.shape[2])
+ elif len(origin_input.shape) == 2:
+ flatten_input = origin_input
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut")
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+ pad_origin_input = torch.nn.functional.pad(origin_input, (0, col_pad, 0, row_pad), "constant", 0)
+ M, N = pad_origin_input.shape[0], pad_origin_input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ input = (
+ input.reshape(row_num, column_num, row_block, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * row_block, column_num * column_block)
+ )
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+ input = input[:M, :N]
+
+ if len(origin_input.shape) > 2:
+ input = input.reshape(origin_input.shape)
+ elif len(origin_input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block reshape")
+
+ return input
+
+
+def block_verify_int8(input, row_block, column_block, layer_type, necessary=True):
+ Binput = block_cut(input, row_block, column_block)
+ Binput = Binput.to(torch.float32)
+
+ for n in range(Binput.shape[0]):
+ unique_values = len(torch.unique(Binput[n, :, :]))
+ if unique_values > 256:
+ if necessary:
+ raise ValueError(f"{layer_type} contains more than 256 unique values.")
+ else:
+ return False
+ return True
+
+
+def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name):
+ Quant_fn = SymmQuantizer
+ return Quant_fn.apply(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name)
+
+
+def extract_bit(string):
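+ # Parse a precision string (e.g. "INT8", "E4M3", "E5M2", "E4M0", "DE8") into a
+ # (quantization type, first bit-width, optional mantissa bits) triple.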
+ match = re.match(r"INT(\d+)", string)
+ if match:
+ return "integer", int(match.group(1)), None
+ match = re.match(r"E(\d+)M(\d+)", string)
+ if match:
+ Ebit, Mbit = int(match.group(1)), int(match.group(2))
+ if Ebit == 1:
+ return "integer", Mbit + 1, None
+ if Mbit == 0:
+ return "floatExM0", int(match.group(1)), 0
+ return "floatExMy", int(match.group(1)), int(match.group(2))
+ match = re.match(r"DE(\d+)", string)
+ if match:
+ return "Dynamic", int(match.group(1)), None
+ match = re.match(r"ZeroD(\d+)", string)
+ if match:
+ return "ZeroDynamic", int(match.group(1)), None
+ raise ValueError(f"{string} data format is not supported")
+
+
+class SymmQuantizer(torch.autograd.function.InplaceFunction):
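+ """Per-block symmetric fake quantizer: scales every block by its absmax so values fit the
+ target format's range, applies the requested rounding, and returns
+ (dequantized, quantized, per-block scale). Backward is a straight-through estimator."""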
+ @staticmethod
+ def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None):
+ with torch.no_grad():
+ absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon
+
+ if bits == "100" or not apply_quantize:
+ return input, input, torch.ones_like(absmax_per_block)
+ elif bits == "FP32":
+ return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block)
+ elif bits == "FP16":
+ return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block)
+ elif bits == "BF16":
+ return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block)
+ else:
+ QuantType, bit1, bit2 = extract_bit(bits)
+ if not symm:
+ bit1 = bit1 + 1 # pretend to be asymmetric
+
+ if QuantType == "integer":
+ Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1
+ elif QuantType == "floatExMy":
+ Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * (
+ 2 ** (2 ** (bit1 - 1))
+ )
+ if bit1 == 4 and bit2 == 3:
+ Qn, Qp = -448, 448
+ if bit1 == 5 and bit2 == 2:
+ Qn, Qp = -57344, 57344
+ elif QuantType == "floatExM0":
+ Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1))
+ elif QuantType == "Dynamic":
+ Qn, Qp = -1, 1
+ elif QuantType == "ZeroDynamic":
+ Qn, Qp = -1, 1
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
+ scale_per_block = (2 * absmax_per_block) / (Qp - Qn)
+ scale_per_block = scale_per_block.to(input)
+
+ Qinput = input / scale_per_block
+
+ if QuantType == "integer":
+ if stochastic:
+ noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5)
+ Qinput.add_(noise)
+ Qinput.clamp_(Qn, Qp).round_()
+ elif QuantType == "floatExMy":
+ # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic)
+ Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic)
+ elif QuantType == "floatExM0":
+ Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic)
+ elif QuantType == "Dynamic":
+ Qinput = Dynamic_quantize_torch(Qinput, bit1, stochastic)
+ elif QuantType == "ZeroDynamic":
+ Qinput = ZeroDynamic_quantize_torch(Qinput, bit1, stochastic)
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
+
+ RQinput = Qinput * scale_per_block
+ # print(f'Layer Name: {layer_name}: {input, Qinput, scale_per_block, absmax_per_block}', file=open('debug.txt', 'a'))
+
+ if input.dtype != Qinput.dtype:
+ print(
+ f"Input type is {input.dtype}, Qinput type is {Qinput.dtype}, scale_per_block type is {scale_per_block.dtype}",
+ file=open("debug.txt", "a"),
+ )
+ import IPython
+
+ IPython.embed()
+ return RQinput, Qinput, scale_per_block
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None, None, None, None, None
+
+
+if __name__ == "__main__":
+ import time
+
+ torch.manual_seed(0)
+ X = (torch.rand(2048, 4096).cuda() - 0.5) * 1000
+
+ X = X.to(torch.bfloat16)
+
+ B_X = block_cut(X, -1, -1)
+ RQ_X, Q_X, S_X = block_quant(
+ B_X, symm=True, bits="E2M0", stochastic=True, epsilon=1e-14, apply_quantize=True, layer_name=""
+ )
+ RQ_X = block_reshape(RQ_X, X, -1, -1)
+ Q_X = block_reshape(Q_X, X, -1, -1)
+
+ # start = time.time()
+ # for _ in range(10000):
+ # B_X = block_cut(X, 2, 2)
+ # RQ_X, Q_X, S_X = block_quant(B_X, symm=True, bits="E5M2", stochastic=True, epsilon=1e-14, apply_quantize=True, layer_name='')
+ # RQ_X = block_reshape(RQ_X, X, 2, 2)
+ # Q_X = block_reshape(Q_X, X, 2, 2)
+ # torch.cuda.synchronize()
+
+ # end = time.time()
+ # print(f"Time cost: {end - start}")
+
+ print(X.dtype)
+ import IPython
+
+ IPython.embed()
+
+ # X = torch.tensor([
+ # [1.1, 1, 1, 2, 2, 2, 3, 3, 3.01],
+ # [1.2, 1, 1, 2, 2, 2, 3, 3, 3.02],
+ # [4.1, 4, 4, 5, 5, 5, 6, 6, 6.01],
+ # [4.2, 4, 4, 5, 5, 5, 6, 6, 6.02],
+ # [7.1, 7, 7, 8, 8, 8, 9, 9, 9.01],
+ # [7.2, 7, 7, 8, 8, 8, 9, 9, 9.02],
+ # ])
+
+ # B_X = block_cut(X, 2, 3)
+ # print(B_X)
+ # print(B_X.shape)
+ # Q_X = block_quant(B_X, symm=True, bits=2, stochastic=False)
+ # print(Q_X)
+ # print(Q_X.shape)
diff --git a/llava/model/quantization/QGELU.py b/llava/model/quantization/QGELU.py
new file mode 100644
index 0000000000000000000000000000000000000000..26fb6f829bd69a8d64b3c437cab92d3271c01240
--- /dev/null
+++ b/llava/model/quantization/QGELU.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import Function, InplaceFunction
+
+try:
+ from .QAct import QAct_FPin, QAct_FPout
+ from .Qconfig import qconfig
+ from .QFunction import *
+ from .utils import *
+
+except ImportError:
+ from Qconfig import qconfig
+ from utils import *
+ from QFunction import *
+ from QAct import QAct_FPout, QAct_FPin
+
+import os
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+
+
+class QGELU(nn.Module):
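+ """Quantization-aware GELU: dequantizes the incoming (Qinput, Iscale) pair, applies nn.GELU
+ in full precision, and re-quantizes the result."""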
+ def __init__(self, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.layer_type = layer_type
+ assert layer_type != "", "layer_type is not defined"
+ assert layer_type in qconfig.qgelu_config, f"{layer_type} not in qgelu_config"
+
+ self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qgelu_config[layer_type])
+
+ self.fbit = self.args.fabit if self.args.fabit else self.Ubit
+ self.bbit = self.args.babit if self.args.babit else self.Ubit
+
+ quantize_flag = format_string_with_condition(
+ layer_type,
+ {"apply": self.apply_quantize},
+ self.args.symm,
+ self.fbit,
+ self.bbit,
+ {"row": self.args.row_blocksize, "col": self.args.col_blocksize},
+ )
+ print(quantize_flag)
+
+ self.gelu = nn.GELU()
+ self.gelu_in = QAct_FPout(args, layer_type=layer_type + "_in")
+ self.gelu_out = QAct_FPin(args, layer_type=layer_type + "_out")
+
+ def forward(self, Qinput, Iscale):
+ # input shape is (Batch Size, Sequence Length, Hidden Size)
+ input_fp = self.gelu_in(Qinput, Iscale)
+ output_fp = self.gelu(input_fp)
+ Qoutput, Iscale = self.gelu_out(output_fp)
+ return Qoutput, Iscale
+
+
+if __name__ == "__main__":
+ Sum = torch.load("tensor/QAct_nan_epoch16.pt")
diff --git a/llava/model/quantization/QIdentity.py b/llava/model/quantization/QIdentity.py
new file mode 100644
index 0000000000000000000000000000000000000000..931d89d2aa5d9feac680d9429b27fd8c33a3ed0b
--- /dev/null
+++ b/llava/model/quantization/QIdentity.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import Function, InplaceFunction
+
+try:
+ from .Qconfig import qconfig
+ from .QFunction import *
+ from .utils import *
+except ImportError:
+ from Qconfig import qconfig
+ from utils import *
+ from QFunction import *
+
+import os
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+
+
+class QIdentity(nn.Identity):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input, scale):
+ input = QuantIdentity.apply(input, scale)
+
+ return input
+
+
+class QuantIdentity(Function):
+ @staticmethod
+ def forward(ctx, input, scale):
+ return input, scale
+
+ @staticmethod
+ def backward(ctx, grad_output, Gscale):
+ return grad_output, Gscale
+
+
+if __name__ == "__main__":
+ Sum = torch.load("tensor/QAct_nan_epoch16.pt")
+ Qinput, Binput, input, args, layer_type, name = (
+ Sum["Qinput"],
+ Sum["Binput"],
+ Sum["input"],
+ Sum["args"],
+ Sum["layer_type"],
+ Sum["name"],
+ )
+ if_nan, if_inf = check_nan_inf(input, True, False)
+ print(if_nan)
+
+ Q = block_quant(Binput, True, 8, stochastic=False, epsilon=1e-8)
diff --git a/llava/model/quantization/QLayerNorm.py b/llava/model/quantization/QLayerNorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a84b41ab801268715385d64e970712802d2b98c
--- /dev/null
+++ b/llava/model/quantization/QLayerNorm.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import Function, InplaceFunction
+
+try:
+ from .QAct import QAct_FPin, QAct_FPout
+ from .Qconfig import qconfig
+ from .QFunction import *
+ from .utils import *
+
+except ImportError:
+ from Qconfig import qconfig
+ from utils import *
+ from QFunction import *
+ from QAct import QAct_FPin, QAct_FPout
+
+import os
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+
+
+class QLayerNorm(nn.Module):
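+ """Quantization-aware LayerNorm: dequantizes the incoming (Qinput, Iscale) pair, applies
+ nn.LayerNorm in full precision, and re-quantizes the normalized output."""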
+ def __init__(self, normalized_shape, eps=1e-5, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.layer_type = layer_type
+ assert layer_type != "", "layer_type is not defined"
+ assert layer_type in qconfig.qlayernorm_config, f"{layer_type} not in qlayernorm_config"
+
+ self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qlayernorm_config[layer_type])
+
+ self.fbit = self.args.fabit if self.args.fabit else self.Ubit
+ self.bbit = self.args.babit if self.args.babit else self.Ubit
+
+ quantize_flag = format_string_with_condition(
+ layer_type,
+ {"apply": self.apply_quantize},
+ self.args.symm,
+ self.fbit,
+ self.bbit,
+ {"row": self.args.row_blocksize, "col": self.args.col_blocksize},
+ )
+ print(quantize_flag)
+
+ self.ln_in = QAct_FPout(args, layer_type=layer_type + "_in")
+ self.layer_norm = nn.LayerNorm(normalized_shape, eps=eps)
+ self.ln_out = QAct_FPin(args, layer_type=layer_type + "_out")
+
+ def forward(self, Qinput, Iscale):
+ # input shape is (Batch Size, Sequence Length, Hidden Size)
+ input = self.ln_in(Qinput, Iscale)
+ output_fp = self.layer_norm(input)
+ # import IPython
+ # IPython.embed()
+ output, scale = self.ln_out(output_fp)
+ return output, scale
+
+
+if __name__ == "__main__":
+ Sum = torch.load("tensor/QAct_nan_epoch16.pt")
diff --git a/llava/model/quantization/QLinear.py b/llava/model/quantization/QLinear.py
new file mode 100644
index 0000000000000000000000000000000000000000..0174dca1019373c27ccd2e5e3031d37faa0bab26
--- /dev/null
+++ b/llava/model/quantization/QLinear.py
@@ -0,0 +1,333 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function, InplaceFunction
+from torch.cuda import amp
+
+from .Qconfig import qconfig
+from .QFunction import *
+from .utils import *
+
+
+class QLinear(nn.Linear):
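+ """Linear layer with simulated (fake) block quantization: during training, weights,
+ activations, outputs and gradients are block-quantized and dequantized around a
+ full-precision F.linear via QuantLinear; at eval time it runs a plain F.linear."""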
+ def __init__(self, in_features, out_features, bias=True, args=None, layer_type=""):
+ super().__init__(in_features, out_features, bias)
+ self.args = deepcopy(args)
+ self.layer_type = layer_type
+ assert layer_type != "", "layer_type is not defined"
+ assert layer_type in qconfig.qlinear_config.keys(), f"{layer_type} not in qlinear_config"
+
+ self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qlinear_config[layer_type])
+ self.apply_quantize_fw, self.apply_quantize_fo, self.apply_quantize_bw, self.apply_quantize_ba = (
+ self.apply_quantize,
+ self.apply_quantize,
+ self.apply_quantize,
+ self.apply_quantize,
+ )
+
+ self.refine_rowcol_blocksize()
+
+ self.fbit = self.args.fwbit if self.args.fwbit else self.Ubit
+ self.bbit = self.args.bwbit if self.args.bwbit else self.Ubit
+ quantize_flag = format_string_with_condition(
+ layer_type,
+ {
+ "apply-fw": self.apply_quantize_fw,
+ "apply-fo": self.apply_quantize_fo,
+ "apply-bw": self.apply_quantize_bw,
+ "apply-ba": self.apply_quantize_ba,
+ },
+ self.args.symm,
+ self.fbit,
+ self.bbit,
+ {
+ "row-fa": self.args.row_blocksize_fa,
+ "col-fa": self.args.col_blocksize_fa,
+ "row-fw": self.args.row_blocksize_fw,
+ "col-fw": self.args.col_blocksize_fw,
+ "row-fo": self.args.row_blocksize_fo,
+ "col-fo": self.args.col_blocksize_fo,
+ "row-ba": self.args.row_blocksize_ba,
+ "col-ba": self.args.col_blocksize_ba,
+ "row-bw": self.args.row_blocksize_bw,
+ "col-bw": self.args.col_blocksize_bw,
+ "row-bo": self.args.row_blocksize_bo,
+ "col-bo": self.args.col_blocksize_bo,
+ },
+ )
+ if quant_get_local_rank() == 0:
+ print(quantize_flag)
+
+ def refine_rowcol_blocksize(self):
+ self.args.row_blocksize_fa, self.args.col_blocksize_fa = self.args.row_blocksize, self.args.col_blocksize
+ self.args.row_blocksize_fw, self.args.col_blocksize_fw = self.args.row_blocksize, self.args.col_blocksize
+ self.args.row_blocksize_fo, self.args.col_blocksize_fo = self.args.row_blocksize, self.args.col_blocksize
+ self.args.row_blocksize_ba, self.args.col_blocksize_ba = self.args.row_blocksize, self.args.col_blocksize
+ self.args.row_blocksize_bw, self.args.col_blocksize_bw = self.args.row_blocksize, self.args.col_blocksize
+ self.args.row_blocksize_bo, self.args.col_blocksize_bo = self.args.row_blocksize, self.args.col_blocksize
+
+ if self.args.refine_attn_blocksize:
+ if self.layer_type in ["attn_q", "attn_k", "attn_v"]:
+ self.apply_quantize_fo = False
+ self.args.row_blocksize_ba, self.args.col_blocksize_ba = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ if self.layer_type in ["attn_proj"]:
+ self.apply_quantize_ba = False
+ self.args.row_blocksize_fo, self.args.col_blocksize_fo = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ if self.args.refine_mlp_blocksize:
+ if self.layer_type in ["mlp_gate", "mlp_up", "mlp_down"]:
+ self.args.row_blocksize_fo, self.args.col_blocksize_fo = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+ self.args.row_blocksize_ba, self.args.col_blocksize_ba = (
+ self.args.refine_row_blocksize,
+ self.args.refine_col_blocksize,
+ )
+
+ def forward(self, Qinput, Iscale):
+ if self.training:
+ output = QuantLinear.apply(
+ Qinput,
+ Iscale,
+ self.weight,
+ self.bias,
+ self.args,
+ self.layer_name,
+ self.apply_quantize_fw,
+ self.apply_quantize_fo,
+ self.apply_quantize_bw,
+ self.apply_quantize_ba,
+ )
+ return output
+ else:
+ output = F.linear(Qinput, self.weight, self.bias)
+ return output, None
+
+
+# class QuantLinear(Function):
+# @staticmethod
+# def forward(ctx, input, weight, bias, args, layer_type):
+# ctx.saved = input, weight, bias, args, layer_type
+# return F.linear(input, weight, bias)
+#
+# @staticmethod
+# def backward(ctx, grad_output):
+# input, weight, bias, args, layer_type = ctx.saved
+#
+# C_in = input.shape[-1]
+# C_out = grad_output.shape[-1]
+#
+# grad_output_flatten = grad_output.reshape(-1, C_out)
+# input_flatten = input.reshape(-1, C_in)
+#
+# if grad_output_flatten.dtype == input_flatten.dtype:
+# grad_weight = grad_output_flatten.t().mm(input_flatten)
+# else:
+# grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+#
+# if grad_output_flatten.dtype == weight.dtype:
+# grad_input = grad_output_flatten.mm(weight)
+# else:
+# grad_input = grad_output_flatten.float().mm(weight)
+#
+# if bias is not None:
+# grad_bias = grad_output_flatten.sum(0)
+# else:
+# grad_bias = None
+#
+# grad_input_transform = grad_input.reshape(input.size())
+#
+# return grad_input_transform, grad_weight, grad_bias, None, None
+
+# B%% = block_cut(%%, args.row_blocksize, args.col_blocksize)
+# RQ%%, Q%%, Wscale = block_quant(B%%, args.symm, args.fwbit, stochastic=False, epsilon=args.epsilon)
+# Q%% = block_reshape(Q%%, %%, args.row_blocksize, args.col_blocksize)
+# RQ%% = block_reshape(RQ%%, %%, args.row_blocksize, args.col_blocksize)
+
+
+class QuantLinear(Function):
+ @staticmethod
+ @amp.custom_fwd(cast_inputs=torch.bfloat16)
+ def forward(
+ ctx,
+ Qinput,
+ Iscale,
+ weight,
+ bias,
+ args,
+ layer_name,
+ apply_quantize_fw=True,
+ apply_quantize_fo=True,
+ apply_quantize_bw=True,
+ apply_quantize_ba=True,
+ ):
+
+ # shrink the padded Iscale down to the number of scale blocks actually used (padding keeps forward/backward sizes consistent for autograd)
+ ideal_scale_num = Qinput.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
+ # actual_scale_num = Qinput.numel() / (args.row_blocksize_fa * args.col_blocksize_fa)
+ assert Iscale.shape[0] == ideal_scale_num
+ Iscale = Iscale[: int(actual_scale_num), :, :]
+
+ Binput = block_cut(Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
+ RQinput = Binput * Iscale
+ RQinput = block_reshape(RQinput, Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
+
+ Bweight = block_cut(weight, args.row_blocksize_fw, args.col_blocksize_fw)
+ RQweight, Qweight, Wscale = block_quant(
+ Bweight,
+ args.symm,
+ args.fwbit,
+ stochastic=False,
+ epsilon=args.epsilon,
+ apply_quantize=apply_quantize_fw,
+ layer_name=layer_name + "WeightQuant",
+ )
+ Qweight = block_reshape(Qweight, weight, args.row_blocksize_fw, args.col_blocksize_fw)
+ RQweight = block_reshape(RQweight, weight, args.row_blocksize_fw, args.col_blocksize_fw)
+
+ if args.draw_distribution_forward:
+ save_tensor(weight, Qweight, RQweight, fb="forward", aw="Weight", layer_name=layer_name)
+
+ ctx.saved = Qinput, Iscale, Qweight, Wscale, bias, args, layer_name
+ ctx.apply_quantize = apply_quantize_fw, apply_quantize_fo, apply_quantize_bw, apply_quantize_ba
+ fc_output = F.linear(RQinput, RQweight, bias)
+
+ Bfc_output = block_cut(fc_output, args.row_blocksize_fo, args.col_blocksize_fo)
+ RQfc_output, Qfc_output, Oscale = block_quant(
+ Bfc_output,
+ args.symm,
+ args.fabit,
+ stochastic=False,
+ epsilon=args.epsilon,
+ apply_quantize=apply_quantize_fo,
+ layer_name=layer_name + "LinearOutput",
+ )
+ RQfc_output = block_reshape(RQfc_output, fc_output, args.row_blocksize_fo, args.col_blocksize_fo)
+ Qfc_output = block_reshape(Qfc_output, fc_output, args.row_blocksize_fo, args.col_blocksize_fo)
+
+ if args.draw_distribution_forward:
+ save_tensor(fc_output, Qfc_output, RQfc_output, fb="forward", aw="Output", layer_name=layer_name)
+
+ # pad Oscale up to the ideal number of scale blocks so its backward gradient can share this shape
+ ideal_scale_num = Qfc_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(Qfc_output, args.row_blocksize_fo, args.col_blocksize_fo)
+ # actual_scale_num = Qfc_output.numel() / (args.row_blocksize_fo * args.col_blocksize_fo)
+ assert Oscale.shape[0] == actual_scale_num
+ Oscale = torch.nn.functional.pad(Oscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))
+
+ return Qfc_output, Oscale
+
+ @staticmethod
+ @amp.custom_bwd
+ def backward(ctx, Qgrad_output, Gscale):
+ Qinput, Iscale, Qweight, Wscale, bias, args, layer_name = ctx.saved
+ apply_quantize_fw, apply_quantize_fo, apply_quantize_bw, apply_quantize_ba = ctx.apply_quantize
+
+ # shrink the padded Gscale down to the number of scale blocks actually used in backward
+ ideal_scale_num = Qgrad_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(Qgrad_output, args.row_blocksize_bo, args.col_blocksize_bo)
+ # actual_scale_num = Qgrad_output.numel() / (args.row_blocksize_bo * args.col_blocksize_bo)
+ assert Gscale.shape[0] == ideal_scale_num
+ Gscale = Gscale[: int(actual_scale_num), :, :]
+
+ Bgrad_output = block_cut(Qgrad_output, args.row_blocksize_bo, args.col_blocksize_bo)
+ RQgrad_output = Bgrad_output * Gscale
+ grad_output = block_reshape(RQgrad_output, Qgrad_output, args.row_blocksize_bo, args.col_blocksize_bo)
+
+ if args.draw_distribution_backward:
+ save_tensor(
+ grad_output, Qgrad_output, RQgrad_output, fb="backward in", aw="Activation", layer_name=layer_name
+ )
+
+ C_in = Qinput.shape[-1]
+ C_out = Qgrad_output.shape[-1]
+
+ Binput = block_cut(Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
+ input = Binput * Iscale
+ input = block_reshape(input, Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
+
+ grad_output_flatten = grad_output.reshape(-1, C_out)
+ input_flatten = input.reshape(-1, C_in)
+
+ if grad_output_flatten.dtype == input_flatten.dtype:
+ grad_weight = grad_output_flatten.t().mm(input_flatten)
+ else:
+ grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+
+ Bgrad_weight = block_cut(grad_weight, args.row_blocksize_bw, args.col_blocksize_bw)
+ RQgrad_weight, Qgrad_weight, GWscale = block_quant(
+ Bgrad_weight,
+ args.symm,
+ args.bwbit,
+ stochastic=True,
+ epsilon=args.epsilon,
+ apply_quantize=apply_quantize_bw,
+ layer_name=layer_name + "WeightGradient",
+ )
+ Qgrad_weight = block_reshape(Qgrad_weight, grad_weight, args.row_blocksize_bw, args.col_blocksize_bw)
+ RQgrad_weight = block_reshape(RQgrad_weight, grad_weight, args.row_blocksize_bw, args.col_blocksize_bw)
+
+ if args.draw_distribution_backward:
+ save_tensor(grad_weight, Qgrad_weight, RQgrad_weight, fb="backward", aw="Weight", layer_name=layer_name)
+
+ # Calculate Activation (Input) Gradient
+ Bweight = block_cut(Qweight, args.row_blocksize_fw, args.col_blocksize_fw)
+ weight = Bweight * Wscale
+ weight = block_reshape(weight, Qweight, args.row_blocksize_fw, args.col_blocksize_fw)
+
+ if grad_output_flatten.dtype == Qweight.dtype:
+ grad_input = grad_output_flatten.mm(weight)
+ else:
+ grad_input = grad_output_flatten.float().mm(weight)
+
+ Bgrad_input = block_cut(grad_input, args.row_blocksize_ba, args.col_blocksize_ba)
+ RQgrad_input, Qgrad_input, GIscale = block_quant(
+ Bgrad_input,
+ args.symm,
+ args.babit,
+ stochastic=True,
+ epsilon=args.epsilon,
+ apply_quantize=apply_quantize_ba,
+ layer_name=layer_name + "ActivationGradient",
+ )
+ Qgrad_input = block_reshape(Qgrad_input, grad_input, args.row_blocksize_ba, args.col_blocksize_ba)
+ RQgrad_input = block_reshape(RQgrad_input, grad_input, args.row_blocksize_ba, args.col_blocksize_ba)
+
+ if args.draw_distribution_backward:
+ save_tensor(
+ grad_input, Qgrad_input, RQgrad_input, fb="backward out", aw="Activation out", layer_name=layer_name
+ )
+
+ # pad GIscale with zeros up to the fixed "ideal" number of scale blocks so it matches the scale shape produced in forward
+ ideal_scale_num = Qgrad_input.numel() / (args.min_blockunit_row * args.min_blockunit_col)
+ actual_scale_num = calculate_scale_num(Qgrad_input, args.row_blocksize_ba, args.col_blocksize_ba)
+ # actual_scale_num = Qgrad_input.numel() / (args.row_blocksize_ba * args.col_blocksize_ba)
+ assert GIscale.shape[0] == actual_scale_num
+ GIscale = torch.nn.functional.pad(GIscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))
+
+ Qgrad_input_transform = Qgrad_input.reshape(Qinput.size())
+
+ if bias is not None:
+ grad_bias = grad_output_flatten.sum(0)
+ else:
+ grad_bias = None
+
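+ # Note (hedged): torch.autograd.Function.backward must return one gradient per input of
+ # forward, so the trailing `None`s below correspond to the non-differentiable arguments
+ # (args, layer_name, quantization flags, and so on).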
+ return Qgrad_input_transform, GIscale, RQgrad_weight, grad_bias, None, None, None, None, None, None
diff --git a/llava/model/quantization/QMul.py b/llava/model/quantization/QMul.py
new file mode 100644
index 0000000000000000000000000000000000000000..5026b144bf2fab1c70f4de25c901fc2a4c671d01
--- /dev/null
+++ b/llava/model/quantization/QMul.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import Function, InplaceFunction
+
+try:
+ from .QAct import QAct_FPin, QAct_FPout
+ from .Qconfig import qconfig
+ from .QFunction import *
+ from .utils import *
+
+except:
+ from Qconfig import qconfig
+ from utils import *
+ from QFunction import *
+ from QAct import QAct_FPin, QAct_FPout
+
+import os
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+
+
+class QMul(nn.Module):
+ def __init__(self, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.layer_type = layer_type
+ assert layer_type != "", "layer_type is not defined"
+ assert layer_type in qconfig.qmul_config, f"{layer_type} not in qmul_config"
+
+ self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qmul_config[layer_type])
+
+ self.fbit = self.args.fabit if self.args.fabit else self.Ubit
+ self.bbit = self.args.babit if self.args.babit else self.Ubit
+
+ quantize_flag = format_string_with_condition(
+ layer_type,
+ {"apply": self.apply_quantize},
+ self.args.symm,
+ self.fbit,
+ self.bbit,
+ {"row": self.args.row_blocksize, "col": self.args.col_blocksize},
+ )
+
+ print(quantize_flag)
+
+ self.Mul_in1 = QAct_FPout(args, layer_type=layer_type + "_in1")
+ self.Mul_in2 = QAct_FPout(args, layer_type=layer_type + "_in2")
+ self.Mul_out = QAct_FPin(args, layer_type=layer_type + "_out")
+
+ def forward(self, Qinput1, Qinput2, Iscale1, Iscale2):
+ # input shape is (Batch Size, Sequence Length, Hidden Size)
+ input1 = self.Mul_in1(Qinput1, Iscale1)
+ input2 = self.Mul_in2(Qinput2, Iscale2)
+ output_fp = input1 * input2
+ Qoutput, Oscale = self.Mul_out(output_fp)
+ return Qoutput, Oscale
+
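+# Example usage (hedged sketch; the gate/up naming is illustrative):
+#   qmul = QMul(args, layer_type="mul_act")
+#   Qout, Oscale = qmul(Qgate, Qup, gate_scale, up_scale)
+# QMul dequantizes both inputs, multiplies them in full precision, and requantizes the product,
+# so the elementwise multiplication stays inside the quantized data path.
+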
+
+if __name__ == "__main__":
+ Sum = torch.load("tensor/QAct_nan_epoch16.pt")
diff --git a/llava/model/quantization/Qconfig.py b/llava/model/quantization/Qconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..e07a9f718209f5c9aa15408ba542f34bc80497bf
--- /dev/null
+++ b/llava/model/quantization/Qconfig.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+class QuantizationConfig:
+ def __init__(self):
+ self.qlinear_config = {
+ "mlp_gate": {"all", "linear"},
+ "mlp_up": {"all", "linear"},
+ "mlp_down": {"all", "linear"},
+ "attn_proj": {"all", "linear"},
+ "attn_q": {"all", "linear"},
+ "attn_k": {"all", "linear"},
+ "attn_v": {"all", "linear"},
+ }
+ self.qact_config = {
+ "mul_act_in1": {"all", "gelu"},
+ "mul_act_in2": {"all", "gelu", "te_like"},
+ "mul_act_out": {"all", "gelu", "te_like"},
+ "mlp_act_sum": {"all", "mlp", "te_like"},
+ "mlp_act_gate": {"all", "mlp", "te_like"},
+ "mlp_act_up": {"all", "mlp", "te_like"},
+ "mlp_act_in": {"all", "mlp", "te_like"},
+ "mlp_act_out": {"all", "mlp"},
+ "ln_attn_in": {"all", "layernorm"},
+ "ln_mlp_in": {"all", "layernorm"},
+ "ln_attn_out": {"all", "layernorm", "te_like"},
+ "ln_mlp_out": {"all", "layernorm", "te_like"},
+ "add_attn_in_re": {"all", "residual"},
+ "add_attn_in_fx": {"all", "residual", "te_like"},
+ "add_mlp_in_re": {"all", "residual"},
+ "add_mlp_in_fx": {"all", "residual", "te_like"},
+ "re_attn_out_re": {"all", "residual"},
+ "re_attn_out_fx": {"all", "residual"},
+ "re_mlp_out_re": {"all", "residual"},
+ "re_mlp_out_fx": {"all", "residual"},
+ "attn_qkv_sum": {"all", "attn", "te_like"},
+ "attn_q_in": {"all", "attn", "te_like"},
+ "attn_k_in": {"all", "attn", "te_like"},
+ "attn_v_in": {"all", "attn", "te_like"},
+ "attn_q_out": {"all", "attn", "te_like"},
+ "attn_k_out": {"all", "attn", "te_like"},
+ "attn_v_out": {"all", "attn", "te_like"},
+ "attn_proj_in": {"all", "attn", "te_like"},
+ }
+
+ self.qgelu_config = {"mlp_gelu": {"all", "gelu"}}
+
+ self.qlayernorm_config = {"ln_attn": {"all", "layernorm"}, "ln_mlp": {"all", "layernorm"}}
+
+ self.qadd_config = {"add_attn": {"all", "residual"}, "add_mlp": {"all", "residual"}}
+
+ self.qmul_config = {
+ "mul_act": {"all", "gelu"},
+ }
+
+
+qconfig = QuantizationConfig()
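+
+# How these sets are used (illustrative note): each quantized module looks up its layer_type in
+# the matching dict and enables quantization when args.qchoice shares at least one tag with the
+# stored set, e.g. list_has_common_element(["linear"], qconfig.qlinear_config["mlp_up"]) is True;
+# "all" acts as a catch-all tag.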
diff --git a/llava/model/quantization/__init__.py b/llava/model/quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceeada93adcc855e13dc5f9414332a74ebaee751
--- /dev/null
+++ b/llava/model/quantization/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .FloatPointQuantizeTorch import *
+from .QAct import *
+from .QAdd import *
+from .QGELU import *
+from .QIdentity import *
+from .QLayerNorm import *
+from .QLinear import *
+from .QMul import *
+from .utils import *
diff --git a/llava/model/quantization/debug.txt b/llava/model/quantization/debug.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3f599ed2db915976a44cefe435cc48fe5f413be6
--- /dev/null
+++ b/llava/model/quantization/debug.txt
@@ -0,0 +1,31 @@
+Layer Name: : (tensor([[[ 44., -127.],
+ [ 38., -96.]],
+
+ [[ -8., -127.],
+ [ -5., -110.]],
+
+ [[ 10., 127.],
+ [ 5., 102.]],
+
+ ...,
+
+ [[ 127., -2.],
+ [ 45., -0.]],
+
+ [[ 21., -127.],
+ [ 42., -1.]],
+
+ [[ -16., 77.],
+ [ 35., 127.]]], device='cuda:5', dtype=torch.float16), tensor([[[ 4444.]],
+
+ [[ 2476.]],
+
+ [[ 1041.]],
+
+ ...,
+
+ [[ 2030.]],
+
+ [[ 4240.]],
+
+ [[20112.]]], device='cuda:5', dtype=torch.float16))
diff --git a/llava/model/quantization/utils.py b/llava/model/quantization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..95696238ebd8c92276eb6bd56367c2ed4b26b6dd
--- /dev/null
+++ b/llava/model/quantization/utils.py
@@ -0,0 +1,213 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def list_has_common_element(list1, list2):
+ set1 = set(list1)
+ set2 = set(list2)
+ return len(set1.intersection(set2)) > 0
+
+
+def calculate_scale_num(input, row_block, col_block):
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if col_block == -1:
+ col_block = N
+
+ return input.numel() / (row_block * col_block)
+
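+# Example (illustrative): a (1024, 4096) tensor cut into 128 x 128 blocks gives
+# calculate_scale_num(x, 128, 128) == 1024 * 4096 / (128 * 128) == 256 scale blocks;
+# a block size of -1 means a single block spanning the whole corresponding dimension.
+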
+
+def quant_get_local_rank() -> int:
+ return int(os.environ.get("LOCAL_RANK") or 0)
+
+
+def format_string_with_condition(
+ input_string,
+ condition_config,
+ symm,
+ fbit,
+ bbit,
+ blocksize_config,
+ input_pad=20,
+):
+ padded_string = input_string.ljust(input_pad)
+ output_string = padded_string
+
+ for k, v in condition_config.items():
+ if v:
+ output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6)
+ else:
+ output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6)
+
+ output_string = (
+ output_string + f"Symm {symm}".ljust(10) + f"Forward bit {fbit}".ljust(20) + f"Backward bit {bbit}".ljust(20)
+ )
+ for k, v in blocksize_config.items():
+ output_string += f"{k}: {v}".ljust(15)
+
+ return output_string
+
+
+def print_warning(sentence):
+ print("*" * (len(sentence) + 4))
+ print(f"* {sentence} *")
+ print("*" * (len(sentence) + 4))
+
+
+def check_nan_inf(tensor, check_nan, check_inf):
+ if check_nan:
+ contain_nan = torch.isnan(tensor).any()
+ else:
+ contain_nan = False
+ if check_inf:
+ contain_inf = torch.isinf(tensor).any()
+ else:
+ contain_inf = False
+ return contain_nan, contain_inf
+
+
+def move_torch_to_numpy(tensor):
+ if tensor is None:
+ return None
+
+ if tensor.is_cuda:
+ tensor = tensor.cpu()
+ return tensor.detach().float().numpy()
+
+
+def flatten_to_1d(tensor):
+ if tensor is None:
+ return None
+
+ return tensor.reshape(-1)
+
+
+def get_uniform_bin(tensor, num_bins, blank=0.05):
+ bin_arr = np.linspace(
+ tensor.min() - (tensor.max() - tensor.min()) * blank,
+ tensor.max() + (tensor.max() - tensor.min()) * blank,
+ num_bins,
+ endpoint=True,
+ )
+ return bin_arr
+
+
+def determine_log_scale_hist(counts, threshold_ratio=3):
+ # find the largest and the third-largest bin counts
+ max_count = np.max(counts)
+ third_max_count = np.partition(counts, -3)[-3]
+
+ # decide whether to use a log scale for the histogram
+ if max_count >= threshold_ratio * third_max_count:
+ return True
+ else:
+ return False
+
+
+def print_list_with_separator(lst):
+ separator = "-" * 30  # long separator line
+
+ for item in lst:
+ print(item, item.dtype)
+ print(separator)
+
+
+def save_tensor(tensor, RQtensor, Qtensor, fb, aw, layer_name):
+ visualize_path = os.path.join("visualize", aw, fb)
+ file_name = f"{layer_name}.pt"
+ os.makedirs(visualize_path, exist_ok=True)
+ torch.save(
+ {"tensor": tensor, "RQtensor": RQtensor, "Qtensor": Qtensor, "fb": fb, "aw": aw, "layer_name": layer_name},
+ os.path.join(visualize_path, file_name),
+ )
+ print(f"{aw} {fb} {layer_name} saved!")
+
+
+def visualize_distribution(pt_path):
+ print(pt_path)
+ saved_tensor = torch.load(pt_path, map_location="cpu")
+ # os.remove(pt_path)
+
+ tensor = saved_tensor["tensor"]
+ RQtensor = saved_tensor["RQtensor"]
+ Qtensor = saved_tensor["Qtensor"]
+ fb = saved_tensor["fb"]
+ aw = saved_tensor["aw"]
+ layer_name = saved_tensor["layer_name"]
+
+ # visualize_path = os.path.join("visualize", aw, fb, layer_name)
+ # file_name = "distribution.png"
+ # os.makedirs(visualize_path, exist_ok=True)
+ visualize_path = os.path.join("visualize", aw, fb)
+ file_name = f"{layer_name}.png"
+ os.makedirs(visualize_path, exist_ok=True)
+
+ # MSE = (tensor - Qtensor).norm().item()
+ tensor, RQtensor, Qtensor = move_torch_to_numpy(tensor), move_torch_to_numpy(RQtensor), move_torch_to_numpy(Qtensor)
+ tensor, RQtensor, Qtensor = flatten_to_1d(tensor), flatten_to_1d(RQtensor), flatten_to_1d(Qtensor)
+ # create the subplot grid
+ fig, axs = plt.subplots(3, 2, figsize=(120, 80))
+ plt.rcParams["font.size"] = 80
+ for ax in axs.flatten():
+ ax.tick_params(axis="both", labelsize=80)
+
+ num_bins = 1000
+ # Tensor distribution - original
+ if tensor is not None:
+ axs[0, 0].hist(tensor, bins=num_bins, color="blue", alpha=0.5)
+ axs[0, 0].set_title(f"Original Distribution of tensor, {tensor.dtype}")
+
+ # Tensor distribution - log scale
+ axs[0, 1].hist(tensor, bins=num_bins, color="blue", alpha=0.5)
+ axs[0, 1].set_yscale("log")
+ axs[0, 1].set_title(f"Log Scale Distribution of tensor, {tensor.dtype}")
+ axs[0, 1].set_xlabel("use log scale")
+
+ # Qtensor distribution - original
+ if RQtensor is not None:
+ axs[1, 0].hist(RQtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[1, 0].set_title(f"Original Distribution of RQtensor, {Qtensor.dtype}")
+
+ # Qtensor distribution - log scale
+ axs[1, 1].hist(RQtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[1, 1].set_yscale("log")
+ axs[1, 1].set_title(f"Log Scale Distribution of RQtensor, {Qtensor.dtype}")
+ axs[1, 1].set_xlabel("use log scale")
+
+ # Qtensor distribution - original
+ if Qtensor is not None:
+ Q_outlier = np.max(np.abs(Qtensor))
+ axs[2, 0].hist(Qtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[2, 0].set_title(f"Original Distribution of Qtensor, {Qtensor.dtype}")
+ # axs[2, 0].set_xlim(-Q_outlier, Q_outlier)
+
+ # Qtensor distribution - log scale
+ axs[2, 1].hist(Qtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[2, 1].set_yscale("log")
+ axs[2, 1].set_title(f"Log Scale Distribution of Qtensor, {Qtensor.dtype}")
+ axs[2, 1].set_xlabel("use log scale")
+ # axs[2, 1].set_xlim(-Q_outlier, Q_outlier)
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(visualize_path, file_name))
+ plt.close(fig)
+ print(f"{aw} {fb} {layer_name} distribution finish!")
+
+ exit(0)
diff --git a/llava/model/qutils.py b/llava/model/qutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0d754a586f78bca6bb1209f3fa12bb2d6351bde
--- /dev/null
+++ b/llava/model/qutils.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def list_has_common_element(list1, list2):
+ set1 = set(list1)
+ set2 = set(list2)
+ return len(set1.intersection(set2)) > 0
+
+
+def calculate_scale_num(input, row_block, col_block):
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if col_block == -1:
+ col_block = N
+
+ return input.numel() / (row_block * col_block)
+
+
+def quant_get_local_rank() -> int:
+ return int(os.environ.get("LOCAL_RANK") or 0)
+
+
+def format_string_with_condition(
+ input_string,
+ condition_config,
+ symm,
+ bits,
+ blocksize_config,
+ input_pad=20,
+):
+ padded_string = input_string.ljust(input_pad)
+ output_string = padded_string
+
+ for k, v in condition_config.items():
+ if v:
+ output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6)
+ else:
+ output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6)
+
+ output_string = output_string + f"Symm {symm}".ljust(10)
+
+ for k, v in bits.items():
+ output_string = output_string + f"{k} bit".ljust(10) + v.ljust(10)
+ for k, v in blocksize_config.items():
+ output_string += f"{k}: {v}".ljust(15)
+
+ return output_string
+
+
+def print_warning(sentence):
+ print("*" * (len(sentence) + 4))
+ print(f"* {sentence} *")
+ print("*" * (len(sentence) + 4))
+
+
+def check_nan_inf(tensor, check_nan, check_inf):
+ if check_nan:
+ contain_nan = torch.isnan(tensor).any()
+ else:
+ contain_nan = False
+ if check_inf:
+ contain_inf = torch.isinf(tensor).any()
+ else:
+ contain_inf = False
+ return contain_nan, contain_inf
+
+
+def move_torch_to_numpy(tensor):
+ if tensor is None:
+ return None
+
+ if tensor.is_cuda:
+ tensor = tensor.cpu()
+ return tensor.detach().float().numpy()
+
+
+def flatten_to_1d(tensor):
+ if tensor is None:
+ return None
+
+ return tensor.reshape(-1)
+
+
+def get_uniform_bin(tensor, num_bins, blank=0.05):
+ bin_arr = np.linspace(
+ tensor.min() - (tensor.max() - tensor.min()) * blank,
+ tensor.max() + (tensor.max() - tensor.min()) * blank,
+ num_bins,
+ endpoint=True,
+ )
+ return bin_arr
+
+
+def determine_log_scale_hist(counts, threshold_ratio=3):
+ max_count = np.max(counts)
+ third_max_count = np.partition(counts, -3)[-3]
+
+ if max_count >= threshold_ratio * third_max_count:
+ return True
+ else:
+ return False
+
+
+def print_list_with_separator(lst):
+ separator = "-" * 30
+
+ for item in lst:
+ print(item, item.dtype)
+ print(separator)
+
+
+def save_tensor(tensor, RQtensor, Qtensor, fb, aw, layer_name):
+ visualize_path = os.path.join("visualize", aw, fb)
+ file_name = f"{layer_name}.pt"
+ os.makedirs(visualize_path, exist_ok=True)
+ torch.save(
+ {"tensor": tensor, "RQtensor": RQtensor, "Qtensor": Qtensor, "fb": fb, "aw": aw, "layer_name": layer_name},
+ os.path.join(visualize_path, file_name),
+ )
+ print(f"{aw} {fb} {layer_name} saved!")
+
+
+def visualize_distribution(pt_path):
+ print(pt_path)
+ saved_tensor = torch.load(pt_path, map_location="cpu")
+ # os.remove(pt_path)
+
+ tensor = saved_tensor["tensor"]
+ RQtensor = saved_tensor["RQtensor"]
+ Qtensor = saved_tensor["Qtensor"]
+ fb = saved_tensor["fb"]
+ aw = saved_tensor["aw"]
+ layer_name = saved_tensor["layer_name"]
+
+ # visualize_path = os.path.join("visualize", aw, fb, layer_name)
+ # file_name = "distribution.png"
+ # os.makedirs(visualize_path, exist_ok=True)
+ visualize_path = os.path.join("visualize", aw, fb)
+ file_name = f"{layer_name}.png"
+ os.makedirs(visualize_path, exist_ok=True)
+
+ # MSE = (tensor - Qtensor).norm().item()
+ tensor, RQtensor, Qtensor = move_torch_to_numpy(tensor), move_torch_to_numpy(RQtensor), move_torch_to_numpy(Qtensor)
+ tensor, RQtensor, Qtensor = flatten_to_1d(tensor), flatten_to_1d(RQtensor), flatten_to_1d(Qtensor)
+
+ fig, axs = plt.subplots(3, 2, figsize=(120, 80))
+ plt.rcParams["font.size"] = 80
+ for ax in axs.flatten():
+ ax.tick_params(axis="both", labelsize=80)
+
+ num_bins = 1000
+ # Tensor distribution - original
+ if tensor is not None:
+ axs[0, 0].hist(tensor, bins=num_bins, color="blue", alpha=0.5)
+ axs[0, 0].set_title(f"Original Distribution of tensor, {tensor.dtype}")
+
+ # Tensor distribution - log scale
+ axs[0, 1].hist(tensor, bins=num_bins, color="blue", alpha=0.5)
+ axs[0, 1].set_yscale("log")
+ axs[0, 1].set_title(f"Log Scale Distribution of tensor, {tensor.dtype}")
+ axs[0, 1].set_xlabel("use log scale")
+
+ # Qtensor distribution - original
+ if RQtensor is not None:
+ axs[1, 0].hist(RQtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[1, 0].set_title(f"Original Distribution of RQtensor, {Qtensor.dtype}")
+
+ # Qtensor distribution - log scale
+ axs[1, 1].hist(RQtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[1, 1].set_yscale("log")
+ axs[1, 1].set_title(f"Log Scale Distribution of RQtensor, {Qtensor.dtype}")
+ axs[1, 1].set_xlabel("use log scale")
+
+ # Qtensor distribution - original
+ if Qtensor is not None:
+ Q_outlier = np.max(np.abs(Qtensor))
+ axs[2, 0].hist(Qtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[2, 0].set_title(f"Original Distribution of Qtensor, {Qtensor.dtype}")
+ # axs[2, 0].set_xlim(-Q_outlier, Q_outlier)
+
+ # Qtensor distribution - log scale
+ axs[2, 1].hist(Qtensor, bins=num_bins, color="red", alpha=0.5)
+ axs[2, 1].set_yscale("log")
+ axs[2, 1].set_title(f"Log Scale Distribution of Qtensor, {Qtensor.dtype}")
+ axs[2, 1].set_xlabel("use log scale")
+ # axs[2, 1].set_xlim(-Q_outlier, Q_outlier)
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(visualize_path, file_name))
+ plt.close(fig)
+ print(f"{aw} {fb} {layer_name} distribution finish!")
+
+ exit(0)
diff --git a/llava/model/realquantize/common.py b/llava/model/realquantize/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad20b543c4043b197c83c2317dae26e9f2589c41
--- /dev/null
+++ b/llava/model/realquantize/common.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+
+SCALE_MIN_THRES = 1e-10
+
+FP8_MAX_VALUE = {
+ torch.float8_e4m3fn: 448,
+ torch.float8_e5m2: 57344,
+}
+
+convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
+convert_fp8_to_embit = {
+ torch.float8_e4m3fn: (4.0, 3.0),
+ torch.float8_e5m2: (5.0, 2.0),
+}
+
+# from .common import SCALE_MIN_THRES, FP8_MAX_VALUE
+# SCALE_MIN_THRES: tl.constexpr,
+# + SCALE_MIN_THRES
+# SCALE_MIN_THRES=SCALE_MIN_THRES,
diff --git a/llava/model/realquantize/division.py b/llava/model/realquantize/division.py
new file mode 100644
index 0000000000000000000000000000000000000000..908b9bfc042782e16ed721d7333a3e849c227d1f
--- /dev/null
+++ b/llava/model/realquantize/division.py
@@ -0,0 +1,288 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+try:
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8
+except:
+ from common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8
+
+"""Quantize Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses full-precision/BF16"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
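+# Example usage (hedged sketch; `x` is a CUDA tensor to be cast down to FP8):
+#   y, s_y = fp8_division(x, QB=16, fp8type="E4M3", stochastic=True)
+#   x_hat = y.float() * s_y  # dequantized approximation of x
+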
+# The kernel with 1 load operation and 4 store operation
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5]:
+ for block_m in [32, 64, 128]:
+ for block_n in [32, 64, 128]:
+ for nwarps in [4, 8, 16]:
+ if block_m == 64 and block_n == 64:
+ continue
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_division_kernel(
+ output_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr, boundary_check=(0, 1))
+ input = input.to(tl.float32)
+ scale_output = tl.load(input_scale_ptr)
+ scale_output = scale_output.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ # Quantize
+ output = tl.div_rn(output, scale_output)
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ noise_block_ptr = tl.make_block_ptr(
+ base=noise_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ noise = tl.load(noise_block_ptr, boundary_check=(0, 1))
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
+
+ output = output.to(output_ptr.type.element_ty)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+@triton.jit
+def _stochastic_rounding(output, noise, e_bit, m_bit):
+ subnormal_min = tl.exp2(2 - tl.exp2(e_bit - 1) - m_bit)
+ # subnormal_should_be = tl.exp2(2 - tl.exp2(e_bit) - 1)
+
+ output_int32 = tl.cast(output, tl.int32, bitcast=True)
+ output_int32 = output_int32 & 0x7F800000
+ output_float32 = tl.cast(output_int32, tl.float32, bitcast=True)
+ output_exp = tl.maximum(output_float32, subnormal_min)
+
+ noise_rescale = tl.exp2(m_bit) + (output_exp == subnormal_min) * (
+ 1 - tl.exp2(m_bit)
+ ) # 2^m_bit for normal, 1 for subnormal
+
+ noise = output_exp * noise / noise_rescale
+ sign = 1 - 2 * libdevice.signbit(output)
+ output = tl.abs(output) + noise
+
+ # tl.device_print("out", output)
+ # tl.device_print("noise", noise)
+
+ minmax_ratio = 2 + (output_exp == subnormal_min) * (tl.exp2(m_bit) - 2) # 2 for normal, and 2^M for subnormal
+ output = sign * tl.clamp(output, min=output_exp, max=minmax_ratio * output_exp)
+
+ return output
+
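+# Hedged reading of _stochastic_rounding above: the bit mask isolates each value's exponent
+# bucket, the uniform noise from [-0.5, 0.5) is rescaled by the mantissa step size of that
+# bucket (or by the bucket itself for subnormals), added to |x|, and the result is clamped to
+# stay inside the bucket before the sign is reapplied, which realizes stochastic rounding to
+# the target ExMy format.
+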
+
+def fp8_division(x, QB, fp8type, s_y=None, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(x, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(x, dtype=fp8type)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ if s_y is None:
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_division_kernel[grid](
+ y,
+ x,
+ s_y,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y, s_y # y_t is expected to be 2D tensor
+
+
+# I change the dtype of both the input tensor and the output tensor. I use torch.float32, torch.float16, and torch.fp8
+
+configs = []
+for SL in [8192]:
+ configs.append(
+ triton.testing.Benchmark( # test different matrix size influence
+ x_names=["CDIM"],
+ x_vals=[1024, 2048, 4096, 8192],
+ line_arg="provider",
+ line_vals=["triton", "torch"],
+ line_names=["triton", "torch"],
+ styles=[("blue", "-"), ("green", "-")],
+ ylabel="time-cost",
+ plot_name=f"FP8gelu",
+ args={"BS": 4, "SL": SL, "QB": 16, "fp8type": torch.float8_e4m3fn, "mode": "time-consuming"},
+ )
+ )
+
+
+@triton.testing.perf_report(configs)
+def bench_load_store(
+ BS, SL, CDIM, QB, fp8type, provider, mode="forward"
+): # I only use triton as the provider, and mode when benchmarking
+ # create data
+ x = torch.randn(BS, SL, CDIM).cuda()
+ _qx = x.reshape(BS, SL, CDIM // QB, QB)
+ sx = _qx.abs().amax(dim=(3)) / FP8_MAX_VALUE[fp8type]
+ sx = sx.to(torch.bfloat16)
+ _qx = (_qx / sx.unsqueeze(3)).to(fp8type)
+ qx = _qx.reshape(BS, SL, CDIM)
+
+ quantiles = [0.5, 0.2, 0.8]
+ # utility functions
+ if provider == "triton":
+
+ def y_fwd():
+ fp8_division(x, QB, fp8type)
+
+ if provider == "torch":
+ torch_gelu = torch.nn.SiLU()
+
+ def y_fwd():
+ return torch_gelu(x)
+
+ # forward pass
+ if mode == "time-consuming":
+ convert_func = lambda ms: ms
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ # backward pass
+ if mode == "gbps":
+ convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ return convert_func(ms), convert_func(max_ms), convert_func(min_ms)
+
+
+def validity_check(BS, SL, CDIM, QB, fp8type=torch.float8_e4m3fn):
+ # create data
+ x = torch.randn(BS * SL, CDIM).cuda()
+
+ # torch result
+ avg_output_triton = torch.zeros_like(x)
+
+ # triton result
+ for _ in range(100):
+ x_triton, s_triton = fp8_division(x, QB, "E4M3", stochastic=False)
+ output_triton = x_triton.float() * s_triton
+
+ avg_output_triton = avg_output_triton + output_triton
+ avg_output_triton /= 100
+
+ import IPython
+
+ IPython.embed()
diff --git a/llava/model/realquantize/division_transpose.py b/llava/model/realquantize/division_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..f75e5173f7f1b8629a001f57b1fb89e76cdb1bc7
--- /dev/null
+++ b/llava/model/realquantize/division_transpose.py
@@ -0,0 +1,311 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+try:
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8
+ from .division import _stochastic_rounding
+except:
+ from common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8
+ from division import _stochastic_rounding
+
+
+"""Quantize and Transpose Operator"""
+"""Input uses full-precision/BF16"""
+"""Output1 uses per-tensor quantization"""
+"""Output2 uses per-tensor quantization and is transposed"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
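+# Example usage (hedged sketch):
+#   y, s_y, y_t = fp8_division_transpose(x, QB=16, fp8type="E4M3")
+#   # y_t is the FP8 transpose of y (always returned as a 2D tensor) and shares the scale s_y
+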
+# The kernel with 1 load operation and 4 store operation
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5]:
+ for block_m in [32, 64, 128]:
+ for block_n in [32, 64, 128]:
+ for nwarps in [4, 8, 16]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(), # triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)
+ # configs=[triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)], #
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_division_transpose_kernel(
+ output_ptr,
+ output_t_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ output_t_stride_0,
+ output_t_stride_1, # output stride
+ SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr, boundary_check=(0, 1))
+ input = input.to(tl.float32)
+ scale_output = tl.load(input_scale_ptr)
+ scale_output = scale_output.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ # Quantize
+ output = tl.fdiv(output, scale_output)
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ noise_block_ptr = tl.make_block_ptr(
+ base=noise_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ noise = tl.load(noise_block_ptr, boundary_check=(0, 1))
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
+
+ output = output.to(output_ptr.type.element_ty)
+ # tl.device_print("3: ", output)
+ output_t = tl.trans(output)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ output_t_block_ptr = tl.make_block_ptr(
+ base=output_t_ptr,
+ shape=(N, M),
+ strides=(output_t_stride_0, output_t_stride_1),
+ offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
+ block_shape=(BLOCK_N, BLOCK_M),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+ tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1))
+
+
+def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(x, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(x, dtype=fp8type)
+ y_t = torch.empty((N, M), dtype=fp8type, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ if s_y is None:
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_division_transpose_kernel[grid](
+ y,
+ y_t,
+ x,
+ s_y,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ y_t.stride(0),
+ y_t.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y, s_y, y_t # y_t is expected to be 2D tensor
+
+
+# I change the dtype of both the input tensor and the output tensor. I use torch.float32, torch.float16, and torch.fp8
+
+configs = []
+for SL in [1024, 2048, 4096, 8192]:
+ configs.append(
+ triton.testing.Benchmark( # test different matrix size influence
+ x_names=["CDIM"],
+ x_vals=[1024, 2048, 4096, 8192],
+ line_arg="provider",
+ line_vals=["triton", "torch"],
+ line_names=["triton", "torch"],
+ styles=[("blue", "-"), ("green", "-")],
+ ylabel="time-cost",
+ plot_name=f"FP8gelu",
+ args={"BS": 4, "SL": SL, "QB": 16, "fp8type": torch.float8_e4m3fn, "mode": "time-consuming"},
+ )
+ )
+
+
+@triton.testing.perf_report(configs)
+def bench_load_store(
+ BS, SL, CDIM, QB, fp8type, provider, mode="forward"
+): # I only use triton as the provider, and mode when benchmarking
+ # create data
+ x = torch.randn(BS, SL, CDIM).cuda()
+ _qx = x.reshape(BS, SL, CDIM // QB, QB)
+ sx = _qx.abs().amax(dim=(3)) / FP8_MAX_VALUE[fp8type]
+ sx = sx.to(torch.bfloat16)
+ _qx = (_qx / sx.unsqueeze(3)).to(fp8type)
+ qx = _qx.reshape(BS, SL, CDIM)
+
+ quantiles = [0.5, 0.2, 0.8]
+ # utility functions
+ if provider == "triton":
+
+ def y_fwd():
+ fp8_division_transpose(x, QB, fp8type)
+
+ if provider == "torch":
+ torch_gelu = torch.nn.SiLU()
+
+ def y_fwd():
+ return torch_gelu(x)
+
+ # forward pass
+ if mode == "time-consuming":
+ convert_func = lambda ms: ms
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=10)
+ # backward pass
+ if mode == "gbps":
+ convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=10)
+ return convert_func(ms), convert_func(max_ms), convert_func(min_ms)
+
+
+def validity_check(BS, SL, CDIM, QB, fp8type=torch.float8_e4m3fn):
+ # create data
+ # x = torch.randn(BS * SL, CDIM).cuda()
+ x = torch.tensor(
+ [
+ [
+ -4.65823793,
+ 0.33293918,
+ 0.33293918,
+ 0.00003,
+ -4.65823793,
+ 0.33293918,
+ 0.33293918,
+ 0.00003,
+ -4.65823793,
+ 0.33293918,
+ 0.33293918,
+ 0.00003,
+ -4.65823793,
+ 0.33293918,
+ 0.33293918,
+ 0.00003,
+ ]
+ ],
+ device="cuda",
+ )
+
+ # torch result
+ avg_output_triton = torch.zeros_like(x)
+ avg_output_triton_t = torch.zeros_like(x)
+
+ # triton result
+ for _ in range(1000):
+ x_triton, s_triton, x_triton_t = fp8_division_transpose(x, QB, "E4M3", stochastic=True)
+
+ output_triton = x_triton.float() * s_triton
+ output_triton_t = x_triton_t.float().t() * s_triton
+
+ avg_output_triton = avg_output_triton + output_triton
+ avg_output_triton_t = avg_output_triton_t + output_triton_t
+ avg_output_triton /= 1000
+ avg_output_triton_t /= 1000
+
+ xx, ss, xxtt = fp8_division_transpose(x, QB, "E4M3", stochastic=False)
+ import IPython
+
+ IPython.embed()
diff --git a/llava/model/realquantize/linear.py b/llava/model/realquantize/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ca1c883dcef4c3395652cb5eba8750a78c99913
--- /dev/null
+++ b/llava/model/realquantize/linear.py
@@ -0,0 +1,389 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+try:
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8
+ from .division import _stochastic_rounding
+except:
+ from common import SCALE_MIN_THRES, FP8_MAX_VALUE, convert_str_to_fp8, convert_fp8_to_embit
+ from division import _stochastic_rounding
+
+import os
+import time
+
+"""Linear Layer Forward + Backward"""
+"""Input uses per-tensor quantization"""
+"""Output is full-precision/BF16 (for FlashAttention) or 1 * 16 quantization (for the rest)"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
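+# Example usage (hedged sketch; qx and qw are FP8 tensors with per-tensor scales sx and sw):
+#   y = fp8_linear_forward(qx, sx, qw, sw, output_quantize=False, QB=16)      # BF16 output
+#   y, s_y = fp8_linear_forward(qx, sx, qw, sw, output_quantize=True, QB=16)  # 1x16-quantized output
+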
+
+def get_configs_io_block():
+ configs = []
+ for nstages in [3]:
+ for block_m in [128, 256]:
+ for block_n in [128, 256]:
+ for block_k in [128, 256]:
+ for nwarps in [8]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+@triton.autotune(
+ configs=get_configs_io_block(),
+ key=["N"],
+)
+@triton.jit
+def _fp8matmul_kernel(
+ A,
+ B,
+ C,
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ K, #
+ stride_am,
+ stride_ak, #
+ stride_bk,
+ stride_bn, #
+ stride_cm,
+ stride_cn, ##
+ Scale_A,
+ Scale_B,
+ Scale_C,
+ stride_scm,
+ stride_scn,
+ output_quantize: tl.constexpr,
+ QB: tl.constexpr, # default to use 1 * 16 quantization
+ BIAS,
+ fp8_max,
+ e_bit,
+ m_bit,
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ grid_m = tl.cdiv(M, BLOCK_M)
+ grid_n = tl.cdiv(N, BLOCK_N)
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ # a = tl.load(A)
+ # b = tl.load(B)
+ k_remaining = K - k * BLOCK_K
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+
+ acc = tl.dot(a, b, acc)
+
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+
+ scale_a = tl.load(Scale_A)
+ scale_b = tl.load(Scale_B)
+ scale_ab = scale_a.to(tl.float32) * scale_b.to(tl.float32)
+ # fp8 dequantize
+ acc = acc * scale_ab
+
+ if BIAS:
+ bias = tl.load(BIAS + rbn)
+ acc = acc + bias
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+
+ if output_quantize:
+ acc = tl.reshape(acc, (BLOCK_M, BLOCK_N // QB, QB))
+ abs_acc = tl.abs(acc)
+ acc_max = tl.max(abs_acc, axis=2) + SCALE_MIN_THRES
+ # tl.device_print("acc_max", acc_max)
+ acc_scale = acc_max / fp8_max
+ # tl.device_print("acc_scale", acc_scale)
+ acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB, 1))
+ acc = tl.div_rn(acc, acc_scale)
+ acc = tl.reshape(acc, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ noise_block_ptr = noise_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ noise = tl.load(noise_block_ptr, boundary_check=(0, 1))
+ acc = _stochastic_rounding(acc, noise, e_bit, m_bit)
+
+ acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB))
+ acc = acc.to(C.dtype.element_ty)
+
+ rsm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rsn = pid_n * BLOCK_N // QB + tl.arange(0, BLOCK_N // QB)
+ Scale_C = Scale_C + (rsm[:, None] * stride_scm + rsn[None, :] * stride_scn)
+
+ tl.store(C, acc, mask=mask, boundary_check=(0, 1))
+ tl.store(Scale_C, acc_scale)
+
+ else:
+ # handles write-back with reduction-splitting
+ acc = acc.to(C.dtype.element_ty)
+ tl.store(C, acc, mask=mask)
+
+
+def fp8matmul(a, b, output_quantize, scale_a, scale_b, QB, bias=None, stochastic=False):
+ # Deal with batched input
+ if len(a.shape) == 3:
+ BS, batched = a.shape[0], True
+ a = a.reshape(-1, a.shape[2])
+ else:
+ batched = False
+
+ # Check constraints.
+ assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+ assert a.is_contiguous(), "Matrix A must be contiguous"
+ M, K = a.shape
+ K, N = b.shape
+ fp8MaxValue = FP8_MAX_VALUE[a.dtype] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[a.dtype]
+
+ # Allocates output.
+ if output_quantize:
+ c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+ # c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+ scale_c = torch.empty((M, N // QB), device=a.device, dtype=torch.float32)
+ else:
+ c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+ scale_c = torch.empty(
+ (1, 1), device=a.device, dtype=torch.bfloat16
+ )  # placeholder only; scale_c is not used when output_quantize is False
+
+ if stochastic:
+ noise = torch.empty_like(c, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+ _fp8matmul_kernel[grid](
+ a,
+ b,
+ c,
+ noise, #
+ M,
+ N,
+ K, #
+ a.stride(0),
+ a.stride(1), #
+ b.stride(0),
+ b.stride(1), #
+ c.stride(0),
+ c.stride(1), #
+ scale_a,
+ scale_b,
+ scale_c,
+ scale_c.stride(0),
+ scale_c.stride(1),
+ output_quantize=output_quantize,
+ QB=QB,
+ BIAS=bias,
+ fp8_max=fp8MaxValue,
+ e_bit=e_bit,
+ m_bit=m_bit,
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ # BLOCK_M=128,
+ # BLOCK_N=256,
+ # BLOCK_K=128,
+ GROUP_M=8,
+ # num_stages=3,
+ # num_warps=8,
+ )
+ # Reshape output to batch
+ if batched:
+ c = c.reshape(BS, -1, N)
+ if output_quantize:
+ scale_c = scale_c.reshape(BS, -1, N // QB)
+ return c, scale_c
+ else:
+ if output_quantize:
+ scale_c = scale_c.reshape(M, N // QB)
+ return c, scale_c
+ return c
+
+
+def fp8_linear_forward(x, s, w, s_w, output_quantize, QB, bias=None):
+ w_t = w.t()
+ return fp8matmul(x, w_t, output_quantize, s, s_w, QB, bias)
+
+
+# def fp8_linear_forward(x, s, w, s_w, output_quantize, QB):
+# print("you are using the wrong linear function. ")
+# w_t = w.t()
+# if output_quantize:
+# return fp8matmul(x, w_t, True, s, s_w, QB)
+# else:
+# y = fp8matmul(x, w_t, False, s, s_w, QB)
+
+# return y
+
+
+def fp8_linear_backward(
+ x_t, s, g, s_g, g_t, w_t, s_w, QB, bias=None, stochastic=False, dgrad_quantize=True
+): # dgrad_quantize=True for backward before flashattention
+ batched = False
+ if len(g.shape) == 3:  # the remaining operands must already be 2D
+ batched = True
+ BS = g.shape[0]
+ g = g.reshape(-1, g.shape[-1])
+
+ w_t_t = w_t.t()
+ x_t_t = x_t.t()
+ if dgrad_quantize:
+ y, s_y = fp8matmul(g, w_t_t, True, s_g, s_w, QB, stochastic=stochastic)
+ else:
+ y = fp8matmul(g, w_t_t, False, s_g, s_w, QB)
+
+ w_g = fp8matmul(g_t, x_t_t, False, s_g, s, QB)
+
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ if dgrad_quantize:
+ if s_y.numel() > 1:
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+ if dgrad_quantize:
+ return y, s_y, w_g
+ else:
+ return y, w_g
+
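+# Example usage (hedged sketch; operand names follow the arguments above):
+#   dx, s_dx, dw = fp8_linear_backward(x_t, s, g, s_g, g_t, w_t, s_w, QB=16,
+#                                      stochastic=True, dgrad_quantize=True)
+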
+
+if __name__ == "__main__":
+
+ # Input = torch.load("/home/hxi/lustre_hxi/workdir/FP8_OLMo/debug_linear.pt")
+ # mul_x_t, mul_s, out_g, out_gs, out_g_t, weight2_t, weight2_s, qgroup_size = Input
+
+ # fc2_g, fc2_gs, weight2_grad = fp8_linear_backward(mul_x_t, mul_s, out_g, out_gs, out_g_t, weight2_t, weight2_s, qgroup_size, stochastic=True)
+
+ # # fc2_x = fp8_linear_forward(flash_x, flash_s, weight2, weight2_s, False, 16)
+ # import IPython
+ # IPython.embed()
+
+ def validity_check(M, N, K):
+ a = torch.randn((M, K), device="cuda", dtype=torch.float32)
+ b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)
+
+ scale_a, scale_b = torch.randn((1), device="cuda", dtype=torch.bfloat16), torch.randn(
+ (1), device="cuda", dtype=torch.bfloat16
+ )
+ a = a.to(torch.float8_e4m3fn)
+ b = b.T
+ b = b.to(torch.float8_e4m3fn)
+
+ output_fp8_y, output_fp8_s = fp8matmul(a, b, True, scale_a, scale_b, 16)
+ a_32, b_32 = a.to(torch.float32), b.to(torch.float32)
+ output_torch = torch.matmul(a_32, b_32) * scale_a * scale_b
+
+ import IPython
+
+ IPython.embed()
+
+ def time_check(M, N, K):
+ a = torch.randn((M, K), device="cuda", dtype=torch.float32)
+ b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)
+
+ scale_a, scale_b = torch.randn((1), device="cuda", dtype=torch.bfloat16), torch.randn(
+ (1), device="cuda", dtype=torch.bfloat16
+ )
+ a = a.to(torch.float8_e4m3fn)
+ b = b.T
+ b = b.to(torch.float8_e4m3fn)
+
+ for _ in range(10):
+ torch.cuda.synchronize()
+ start = time.time()
+ output_fp8_y = fp8matmul(a, b, False, scale_a, scale_b, 16)
+ torch.cuda.synchronize()
+ end = time.time()
+ print(end - start)
+
+ # import IPython
+ # IPython.embed()
+
+ configs = []
+ for fp8_inputs in [True]:
+ configs.append(
+ triton.testing.Benchmark(
+ x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot
+ x_vals=[512 * i for i in range(2, 17)], # Different possible values for `x_name`
+ line_arg="provider", # Argument name whose value corresponds to a different line in the plot
+ # Possible values for `line_arg`
+ # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
+ line_vals=["triton"] if fp8_inputs else ["cublas", "triton"], # Label name for the lines
+ line_names=["Triton"] if fp8_inputs else ["cuBLAS", "Triton"], # Line styles
+ styles=[("green", "-"), ("blue", "-")],
+ ylabel="TFLOPS", # Label name for the y-axis
+ plot_name="matmul-performance-"
+ + (
+ "fp16" if not fp8_inputs else "fp8"
+ ), # Name for the plot, used also as a file name for saving the plot.
+ args={"fp8_inputs": fp8_inputs},
+ )
+ )
+
+ @triton.testing.perf_report(configs)
+ def benchmark(M, N, K, provider, fp8_inputs):
+ a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
+ b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)
+ if fp8_inputs:
+ a = a.to(torch.float8_e4m3fn)
+ b = b.T
+ b = b.to(torch.float8_e4m3fn)
+ scale_a, scale_b = torch.randn((1), device="cuda", dtype=torch.bfloat16), torch.randn(
+ (1), device="cuda", dtype=torch.bfloat16
+ )
+ quantiles = [0.5, 0.2, 0.8]
+ if provider == "cublas":
+ import IPython
+
+ IPython.embed()
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+ if provider == "triton":
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ lambda: fp8matmul(a, b, False, scale_a, scale_b, 16), quantiles=quantiles
+ )
+ perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
+ return perf(ms), perf(max_ms), perf(min_ms)
+
+ torch.set_printoptions(sci_mode=False, linewidth=200, precision=6)
+ # time_check(4096, 11008, 5380)
+ # validity_check(2048, 1024, 4096)
+ benchmark.run(show_plots=True, print_data=True)
diff --git a/llava/model/realquantize/quantize_and_transpose.py b/llava/model/realquantize/quantize_and_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..5df117fc47bb23c9af99870aef8b2cf2a8195451
--- /dev/null
+++ b/llava/model/realquantize/quantize_and_transpose.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES
+from .division_transpose import fp8_division_transpose
+
+"""Quantize and Transpose Operator"""
+"""Input uses floating point tensor"""
+"""Output uses per-tensor quantization, returns a non-transpose version and a transpose version"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
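+# Example usage (hedged sketch):
+#   qy, s_y, qy_t = fp8_quantize_and_transpose(x, QB=16, fp8type="E4M3")
+# The kernel below only computes the per-group maxima; they are folded into a single per-tensor
+# scale (s_y.max()) and the actual FP8 cast is delegated to fp8_division_transpose.
+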
+# The kernel with 1 load operation and 4 store operation
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5]:
+ for block_m in [32, 64, 128]:
+ for block_n in [32, 64, 128]:
+ for nwarps in [4, 8, 16]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_quantize_and_transpose_kernel(
+ output_scale_ptr, # output
+ input_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr, boundary_check=(0, 1))
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize_and_transpose(x, QB, fp8type, transpose_output_2d=False, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ fp8type = convert_str_to_fp8[fp8type]
+ s_y = torch.empty((M, SN), dtype=torch.float32, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_quantize_and_transpose_kernel[grid](
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(
+ x, QB, fp8type, s_y_max, stochastic=stochastic
+ ) # Stochastic Rounding happens here
+
+ # Recover 2D to 3D
+ if batched:
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy_t.shape[-1])
+
+ return qy, s_y_max, qy_t # y_t is expected to be 2D tensor
+
+
+# I change the dtype of both the input tensor and the output tensor. I use torch.float32, torch.float16, and torch.fp8
+
+configs = []
+for SL in [8192]:
+ configs.append(
+ triton.testing.Benchmark( # test different matrix size influence
+ x_names=["CDIM"],
+ x_vals=[1024, 2048, 4096, 8192],
+ line_arg="provider",
+ line_vals=["triton", "torch"],
+ line_names=["triton", "torch"],
+ styles=[("blue", "-"), ("green", "-")],
+ ylabel="time-cost",
+ plot_name=f"FP8gelu",
+ args={"BS": 4, "SL": SL, "QB": 16, "fp8type": torch.float8_e4m3fn, "mode": "time-consuming"},
+ )
+ )
+
+
+@triton.testing.perf_report(configs)
+def bench_load_store(
+ BS, SL, CDIM, QB, fp8type, provider, mode="forward"
+): # I only use triton as the provider, and mode when benchmarking
+ # create data
+ x = torch.randn(BS, SL, CDIM).cuda()
+ _qx = x.reshape(BS, SL, CDIM // QB, QB)
+ sx = _qx.abs().amax(dim=(3)) / FP8_MAX_VALUE[fp8type]
+ sx = sx.to(torch.bfloat16)
+ _qx = (_qx / sx.unsqueeze(3)).to(fp8type)
+ qx = _qx.reshape(BS, SL, CDIM)
+
+ quantiles = [0.5, 0.2, 0.8]
+ # utility functions
+ if provider == "triton":
+
+ def y_fwd():
+ fp8_quantize_and_transpose(x, QB, "E4M3")
+
+ if provider == "torch":
+ torch_gelu = torch.nn.SiLU()
+
+ def y_fwd():
+ return torch_gelu(x)
+
+ # forward pass
+ if mode == "time-consuming":
+ convert_func = lambda ms: ms
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ # backward pass
+ if mode == "gbps":
+ convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ return convert_func(ms), convert_func(max_ms), convert_func(min_ms)
+
+
+def validity_check(BS, SL, CDIM, QB, fp8type=torch.float8_e4m3fn):
+ # create data
+ x = torch.randn(BS * SL, CDIM).cuda()
+
+ # torch result
+
+ # triton result
+ x_triton, s_triton, x_triton_t = fp8_quantize_and_transpose(x, QB, "E4M3")
+
+ _x_triton = x_triton.reshape(BS * SL, CDIM // QB, QB)
+ _x_triton = _x_triton.to(torch.float32)
+ s_triton = s_triton.unsqueeze(2)
+ output_triton = (_x_triton * s_triton).reshape(BS * SL, CDIM)
+
+ import IPython
+
+ IPython.embed()
+
+
+if __name__ == "__main__":
+ torch.manual_seed(0)
+ torch.set_printoptions(precision=8, linewidth=1600, sci_mode=False, edgeitems=3)
+ validity_check(BS=4, SL=256, CDIM=512, QB=16, fp8type=torch.float8_e4m3fn)
+ bench_load_store.run(save_path=f"result/time/multi_quantize_block_quantize/BLSZ=64", print_data=True)
diff --git a/llava/model/realquantize/trans_grad_bias.py b/llava/model/realquantize/trans_grad_bias.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf06d31575198396bbe522291fbf8f263050e3ab
--- /dev/null
+++ b/llava/model/realquantize/trans_grad_bias.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES
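+# NOTE (assumed): `fp8_division_transpose`, used by fp8_quantize_and_transpose() below, must be
+# importable here, e.g. `from .division_transpose import fp8_division_transpose` (module name assumed).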
+
+"""Calculate the gradient of bias Operator"""
+"""Input uses per-tensor quantization, and should be transposed"""
+"""Output uses similar to the bias shape"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
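+# Shape sketch (illustrative): for a 2D input of shape (M, N) and group size QB, the kernel below
+# writes one scale per QB consecutive channels, i.e. a scale tensor of shape (M, N // QB).
+# For example, M=4096, N=1024, QB=16 gives a (4096, 64) scale tensor.
+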
+# The kernel performs one load (the input block) and one store (the per-group scale block)
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5]:
+ for block_m in [32, 64, 128]:
+ for block_n in [32, 64, 128]:
+ for nwarps in [4, 8, 16]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
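+# Note: the sweep above yields 3 (num_stages) x 3 (BLOCK_M) x 3 (BLOCK_N) x 3 (num_warps) = 81
+# autotune candidates, so the first call for a new value of N pays a one-time tuning cost.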
+
+
+convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_trans_grad_bias_kernel(
+ output_scale_ptr, # output
+ input_t_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+        base=input_t_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr, boundary_check=(0, 1))
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize_and_transpose(x, QB, fp8type, transpose_output_2d=False, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ fp8type = convert_str_to_fp8[fp8type]
+ s_y = torch.empty((M, SN), dtype=torch.float32, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+    _fp8_trans_grad_bias_kernel[grid](
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(
+ x, QB, fp8type, s_y_max, stochastic=stochastic
+ ) # Stochastic Rounding happens here
+
+ # Recover 2D to 3D
+ if batched:
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy_t.shape[-1])
+
+ return qy, s_y_max, qy_t # y_t is expected to be 2D tensor
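+# Minimal usage sketch (illustrative only; assumes a CUDA device and that `fp8_division_transpose`
+# is available in this package):
+#
+#   x = torch.randn(4, 256, 1024, device="cuda")        # (BS, SL, CDIM)
+#   qy, s_max, qy_t = fp8_quantize_and_transpose(x, 16, "E4M3")
+#   # qy keeps x's (possibly 3D) shape, s_max is the per-tensor scale from the division/transpose
+#   # step, and qy_t is the transposed quantized tensor (2D unless transpose_output_2d is False).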
+
+
+# The benchmark below varies the dtypes of the input and output tensors (torch.float32, torch.float16, and FP8).
+
+configs = []
+for SL in [8192]:
+ configs.append(
+ triton.testing.Benchmark( # test different matrix size influence
+ x_names=["CDIM"],
+ x_vals=[1024, 2048, 4096, 8192],
+ line_arg="provider",
+ line_vals=["triton", "torch"],
+ line_names=["triton", "torch"],
+ styles=[("blue", "-"), ("green", "-")],
+ ylabel="time-cost",
+ plot_name=f"FP8gelu",
+ args={"BS": 4, "SL": SL, "QB": 16, "fp8type": torch.float8_e4m3fn, "mode": "time-consuming"},
+ )
+ )
+
+
+@triton.testing.perf_report(configs)
+def bench_load_store(
+    BS, SL, CDIM, QB, fp8type, provider, mode="time-consuming"
+):  # provider and mode come from the benchmark configs; "triton" is the provider of interest
+ # create data
+ x = torch.randn(BS, SL, CDIM).cuda()
+ _qx = x.reshape(BS, SL, CDIM // QB, QB)
+ sx = _qx.abs().amax(dim=(3)) / FP8_MAX_VALUE[fp8type]
+ sx = sx.to(torch.bfloat16)
+ _qx = (_qx / sx.unsqueeze(3)).to(fp8type)
+ qx = _qx.reshape(BS, SL, CDIM)
+
+ quantiles = [0.5, 0.2, 0.8]
+ # utility functions
+ if provider == "triton":
+
+ def y_fwd():
+            fp8_quantize_and_transpose(x, QB, "E4M3")  # signature is (x, QB, fp8type), as in validity_check below
+
+ if provider == "torch":
+ torch_gelu = torch.nn.SiLU()
+
+ def y_fwd():
+ return torch_gelu(x)
+
+    # timing mode: report raw milliseconds
+ if mode == "time-consuming":
+ convert_func = lambda ms: ms
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+    # bandwidth mode: convert milliseconds to GB/s
+ if mode == "gbps":
+ convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ return convert_func(ms), convert_func(max_ms), convert_func(min_ms)
+
+
+def validity_check(BS, SL, CDIM, QB, fp8type=torch.float8_e4m3fn):
+ # create data
+ x = torch.randn(BS * SL, CDIM).cuda()
+
+    # torch reference result: compared interactively via IPython.embed() below
+
+ # triton result
+ x_triton, s_triton, x_triton_t = fp8_quantize_and_transpose(x, QB, "E4M3")
+
+ _x_triton = x_triton.reshape(BS * SL, CDIM // QB, QB)
+ _x_triton = _x_triton.to(torch.float32)
+ s_triton = s_triton.unsqueeze(2)
+ output_triton = (_x_triton * s_triton).reshape(BS * SL, CDIM)
+
+ import IPython
+
+ IPython.embed()
+
+
+if __name__ == "__main__":
+ torch.manual_seed(0)
+ torch.set_printoptions(precision=8, linewidth=1600, sci_mode=False, edgeitems=3)
+ validity_check(BS=4, SL=256, CDIM=512, QB=16, fp8type=torch.float8_e4m3fn)
+ bench_load_store.run(save_path=f"result/time/multi_quantize_block_quantize/BLSZ=64", print_data=True)
diff --git a/llava/model/utils/__init__.py b/llava/model/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76ccca7e49181c27dce49da4f49de8968034d63c
--- /dev/null
+++ b/llava/model/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .utils import *
diff --git a/llava/model/utils/packing.py b/llava/model/utils/packing.py
new file mode 100644
index 0000000000000000000000000000000000000000..767237fc7214378132ba6416bc0fbb6857703329
--- /dev/null
+++ b/llava/model/utils/packing.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from importlib import import_module
+from typing import Tuple
+
+import torch
+import transformers
+from torch import nn
+from torch.nn import functional as F
+
+__all__ = ["patch"]
+
+
+def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ if hasattr(_get_unpad_data, "seqlens_in_batch"):
+ seqlens_in_batch = _get_unpad_data.seqlens_in_batch
+ else:
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
+
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return indices, cu_seqlens, max_seqlen_in_batch
+
+
+def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None:
+ _get_unpad_data.seqlens_in_batch = seqlens_in_batch
+
+
+def patch(model: nn.Module) -> None:
+    # Use packaging (shipped with transformers) for a proper version comparison; plain string
+    # comparison misorders versions such as "4.9.0" vs "4.43.0".
+    from packaging import version
+
+    if version.parse(transformers.__version__) < version.parse("4.43.0"):
+ m = import_module(model.__module__)
+ if not hasattr(m, "_get_unpad_data"):
+ raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing")
+ m._get_unpad_data = _get_unpad_data
+ else:
+ transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data
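+
+
+# Usage sketch (illustrative): when sequences are packed, record the true per-sequence lengths
+# before the forward pass and patch the model once, so flash-attention unpadding uses those
+# lengths instead of re-deriving them from the attention mask:
+#
+#   set_seqlens_in_batch(torch.tensor([5, 3, 7], dtype=torch.int32))
+#   patch(model)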
diff --git a/llava/model/utils/utils.py b/llava/model/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c528f32f1cab041c67becbd79612054b3c9c06
--- /dev/null
+++ b/llava/model/utils/utils.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+import os
+import os.path as osp
+
+from huggingface_hub import repo_exists, snapshot_download
+from huggingface_hub.utils import HFValidationError, validate_repo_id
+from transformers import AutoConfig, PretrainedConfig
+
+
+def get_model_config(config):
+ default_keys = ["llm_cfg", "vision_tower_cfg", "speech_tower_cfg","sound_tower_cfg", "mm_projector_cfg", "speech_mm_projector_cfg", "sound_mm_projector_cfg"]
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
+ root_path = config._name_or_path
+ else:
+ root_path = config.resume_path
+ # download from huggingface
+ if root_path is not None and not osp.exists(root_path):
+ try:
+ valid_hf_repo = repo_exists(root_path)
+ except HFValidationError as e:
+ valid_hf_repo = False
+ if valid_hf_repo:
+ root_path = snapshot_download(root_path)
+
+ return_list = []
+ for key in default_keys:
+ cfg = getattr(config, key, None)
+ if isinstance(cfg, dict):
+ try:
+ return_list.append(os.path.join(root_path, key[:-4]))
+ except:
+ raise ValueError(f"Cannot find resume path in config for {key}!")
+ elif isinstance(cfg, PretrainedConfig):
+ return_list.append(os.path.join(root_path, key[:-4]))
+ elif isinstance(cfg, str):
+ return_list.append(cfg)
+
+ return return_list
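+# Illustrative example: for a checkpoint rooted at <root>, the returned list is ordered as
+# [<root>/llm, <root>/vision_tower, <root>/speech_tower, <root>/sound_tower, <root>/mm_projector,
+#  <root>/speech_mm_projector, <root>/sound_mm_projector] (each key drops its "_cfg" suffix);
+# entries whose config value is already a string are passed through unchanged, and absent keys are skipped.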
+
+
+def get_model_config_fp8(config):
+ default_keys = ["llm_cfg", "vision_tower_cfg", "speech_tower_cfg","sound_tower_cfg", "mm_projector_cfg", "speech_mm_projector_cfg", "sound_mm_projector_cfg"]
+
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
+ root_path = config._name_or_path
+ else:
+ root_path = config.resume_path
+
+ # download from huggingface
+ if root_path is not None and not osp.exists(root_path):
+ try:
+ valid_hf_repo = repo_exists(root_path)
+ except HFValidationError as e:
+ valid_hf_repo = False
+ if valid_hf_repo:
+ root_path = snapshot_download(root_path)
+
+ return_list = []
+ for key in default_keys:
+ cfg = getattr(config, key, None)
+ if isinstance(cfg, dict):
+ try:
+ return_list.append(os.path.join(root_path, key[:-4]))
+ except:
+ raise ValueError(f"Cannot find resume path in config for {key}!")
+ elif isinstance(cfg, PretrainedConfig):
+ return_list.append(os.path.join(root_path, key[:-4]))
+ elif isinstance(cfg, str):
+ return_list.append(cfg)
+
+ # fp8_llm
+ key = "fp8_llm_cfg"
+ directory_path = os.path.join(root_path, key[:-4])
+ assert os.path.isdir(directory_path) and os.listdir(
+ directory_path
+ ), "You need to first convert the model weights to FP8 explicitly."
+ return_list.append(directory_path)
+
+ return return_list
+
+
+def is_mm_model(model_path):
+ """
+ Check if the model at the given path is a visual language model.
+
+ Args:
+ model_path (str): The path to the model.
+
+ Returns:
+ bool: True if the model is an MM model, False otherwise.
+ """
+ config = AutoConfig.from_pretrained(model_path)
+ architectures = config.architectures
+ for architecture in architectures:
+ if "llava" in architecture.lower():
+ return True
+ return False
+
+
+def auto_upgrade(config):
+ cfg = AutoConfig.from_pretrained(config)
+ if "llava" in config and "llava" not in cfg.model_type:
+ assert cfg.model_type == "llama"
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+ if confirm.lower() in ["y", "yes"]:
+ print("Upgrading checkpoint...")
+ assert len(cfg.architectures) == 1
+ setattr(cfg.__class__, "model_type", "llava")
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
+ cfg.save_pretrained(config)
+ print("Checkpoint upgraded.")
+ else:
+ print("Checkpoint upgrade aborted.")
+ exit(1)
diff --git a/llava/train/__init__.py b/llava/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69923298db60fa8275bf22895fd5c54cf102e9b9
--- /dev/null
+++ b/llava/train/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
diff --git a/llava/train/args.py b/llava/train/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7cbffcd92a9dc50ffc9d58b6b10c0596c827045
--- /dev/null
+++ b/llava/train/args.py
@@ -0,0 +1,293 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+import transformers
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ image_folder: Optional[str] = field(default=None)
+ image_aspect_ratio: Optional[str] = "resize"
+ min_tiles: Optional[int] = 1
+ max_tiles: Optional[int] = 12
+ video_max_tiles: Optional[int] = 1 # value larger than 1 means we're training w/ tiling for videos.
+ audio_frames: Optional[int] = 5
+ data_mixture: str = "llava_1_5_mm_align"
+ eval_data_mixture: str = None
+ vflan_no_system_prompt: bool = False
+ downsample_video: bool = False
+
+ # for video training
+ num_video_frames: int = 8
+ fps: float = 0.0 # 0.0 means we do not use fps at all. Always sample the same number of frames.
+
+
+@dataclass
+class ModelArguments:
+ version: Optional[str] = field(default="auto")
+ chat_template: Optional[str] = field(default=None)
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ vision_tower: Optional[str] = field(default="google/siglip-so400m-patch14-384")
+ speech_tower: Optional[str] = field(default="openai/whisper-large-v2")
+ sound_tower: Optional[str] = field(default="imagebind_huge.pth")
+ mm_projector: Optional[str] = field(default="mlp2x_gelu")
+ speech_mm_projector: Optional[str] = field(default="mlp")
+ sound_mm_projector: Optional[str] = field(default="mlp")
+ mm_use_im_start_end: bool = field(default=False)
+ mm_use_im_patch_token: bool = field(default=False)
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ mm_vision_select_feature: Optional[str] = field(default="patch")
+ vision_resolution: Optional[int] = field(default=-1)
+ interpolate_mode: Optional[str] = field(default="linear")
+ drop_path_rate: Optional[float] = field(default=0.0)
+ mlp_path: Optional[str] = field(default=None)
+ s2: bool = field(default=False)
+ dynamic_s2: bool = field(default=False)
+ s2_scales: Optional[str] = field(default="336,672,1008")
+ s2_max_split_size: int = field(default=336)
+ num_time_tokens: int = field(default=0)
+ time_token_format: str = field(default="")
+ soft_ce_std: float = field(default=1.0)
+
+ image_encoder: str = field(default='{"_target_": "llava.model.encoders.BasicImageEncoder"}')
+ video_encoder: str = field(default='{"_target_": "llava.model.encoders.BasicVideoEncoder"}')
+ speech_encoder: str = field(default='{"_target_": "llava.model.encoders.BasicSpeechEncoder"}')
+ sound_encoder: str = field(default='{"_target_": "llava.model.encoders.BasicSoundEncoder"}')
+
+ s2_resize_output_to_scale_idx: int = field(default=0)
+
+ # Quantization and low precision training
+ quantize_model: Optional[str] = field(default="false")
+ symm: Optional[bool] = field(default=True)
+
+ epsilon: Optional[float] = field(default=1e-10)
+ fabit: Optional[str] = field(default="E4M3")
+ fwbit: Optional[str] = field(default="E4M3")
+ bobit: Optional[str] = field(default="E5M2")
+ row_blocksize: Optional[int] = -1 # -1 means only 1 quantization group along row axis
+ col_blocksize: Optional[int] = -1 # -1 means only 1 quantization group along column axis
+ qchoice: Optional[list[str]] = field(
+ default_factory=lambda: [
+ "none",
+ "all",
+ "linear",
+ "mlp",
+ "attn",
+ "gelu",
+ "layernorm",
+ "backbone",
+ "residual",
+ "backbone",
+ ],
+ )
+
+    pad_to_multiple_of: int = 0  # If sequence length * batch size is not divisible by 128, the Triton FP8 matmul used for the weight gradient becomes highly inefficient, so the sequence length is padded to a power-of-two multiple (e.g. batch 4 x length 1000 = 4000 is not a multiple of 128, but padding to length 1024 gives 4096 = 32 x 128). Used in prepare_inputs_labels_for_multimodal().
+
+ # Memory Efficient FP8 related
+ Ubit: str = field(default="100")
+ quantize_model: str = field(default="false", metadata={"help": "Enable model quantization"})
+ symm: bool = field(default=True, metadata={"help": "Use symmetric quantization"})
+ epsilon: float = field(default=1e-10, metadata={"help": "Small epsilon for numerical stability"})
+ fabit: str = field(default="E4M3", metadata={"help": "Bit format for forward activation"})
+ fwbit: str = field(default="E4M3", metadata={"help": "Bit format for forward weights"})
+ fobit: str = field(default="E4M3", metadata={"help": "Bit format for forward output"})
+ babit: str = field(default="E5M2", metadata={"help": "Bit format for backward activation"})
+ bwbit: str = field(default="E5M2", metadata={"help": "Bit format for backward weights"})
+ bobit: str = field(default="E5M2", metadata={"help": "Bit format for backward output"})
+ qchoice: str = field(default="none", metadata={"help": "Quantization choice"})
+ group_size: int = field(default=-1, metadata={"help": "Group size for quantization"})
+ weight_memory_efficient: bool = field(default=True, metadata={"help": "Enable memory-efficient weights"})
+
+ min_blockunit_row: int = field(default=4)
+ min_blockunit_col: int = field(default=4)
+ refine_residual_fp: bool = field(default=False)
+ refine_ln_pertoken: bool = field(default=False)
+ refine_ln_blocksize: bool = field(default=False)
+ refine_ln_blocksize_but_only_forward: bool = field(default=False)
+ refine_ln_blocksize_but_only_backward: bool = field(default=False)
+ refine_attn_blocksize: bool = field(default=False)
+ refine_mlp_blocksize: bool = field(default=False)
+ refine_row_blocksize: int = field(default=4)
+ refine_col_blocksize: int = field(default=4)
+ draw_distribution_forward: bool = field(default=False)
+ draw_distribution_backward: bool = field(default=False)
+
+ # Quantize Optimizer Related
+ use_quantize_optimizer: bool = field(default=False)
+ row_blocksize_optimizer: int = field(default=1)
+ col_blocksize_optimizer: int = field(default=128)
+ pad_block: bool = field(default=False)
+ first_order_bit: Optional[str] = field(default=None)
+ first_order_quant_type: Optional[str] = field(default=None)
+ second_order_bit: Optional[str] = field(default=None)
+ second_order_quant_type: Optional[str] = field(default=None)
+ epsilon_optimizer: float = field(default=1e-15)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ mpt_attn_impl: Optional[str] = field(default="triton")
+ tune_vision_tower: bool = field(default=False)
+ tune_speech_tower: bool = field(default=False)
+ tune_sound_tower: bool = field(default=False)
+ tune_language_model: bool = field(default=False)
+ tune_mm_projector: bool = field(default=False)
+ tune_speech_mm_projector: bool = field(default=False)
+ tune_sound_mm_projector: bool = field(default=False)
+ model_dtype: str = field(default="torch.bfloat16")
+ model_max_length: int = field(
+ default=512,
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."},
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
+ )
+ bits: int = field(default=16, metadata={"help": "How many bits to use."})
+ # lora-related
+ lora_enable: bool = False
+ use_dora: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ lora_llm: bool = False
+ lora_vt: bool = False
+ lora_st: bool = False
+ lora_sot: bool = False
+ dpo: bool = False
+ use_one_logger: bool = True
+ longvila_sampler: bool = False
+ dpo_beta: float = field(default=0.1)
+ mm_projector_lr: Optional[float] = None
+ speech_mm_projector_lr: Optional[float] = None
+ sound_mm_projector_lr: Optional[float] = None
+ vision_tower_lr: Optional[float] = None
+ speech_tower_lr: Optional[float] = None
+ sound_tower_lr: Optional[float] = None
+ group_by_modality_length: bool = field(default=False)
+ total_time_limit: int = field(default=-1, metadata={"help": "Timeout limit for this job (in minutes)."})
+ pre_terminate_time: int = field(
+ default=10,
+ metadata={"help": "Time to terminate the task inadvance (minutes), saveing checkpoints needs time."},
+ )
+ seq_parallel_size: int = field(
+ default=-1,
+ metadata={"help": "The degree of sequence parallelism (SP). SP is disabled by default (value: -1). "},
+ )
+ seq_parallel_ring_size: int = field(
+ default=-1,
+ metadata={
+ "help": "The communication process group size using optimized Ring Attention approach in SP, where `seq_parallel_size` = `seq_parallel_ring_size` x `seq_parallel_ulysses_size` (determined by other two terms). Ring Attention approach is disabled by default in SP. This setting is adjustable only when `seq_parallel_size` > 1."
+ },
+ )
+ seq_parallel_ring_type: str = field(
+ default="ring_varlen",
+ metadata={
+ "help": "Ring Attention implementation. Support ['ring_varlen', 'zigzag_ring_varlen'] in 2D attention. Only works when `seq_parallel_ring_size` > 1."
+ },
+ )
+ debug_e2e: bool = field(
+ default=False,
+ metadata={"help": "Whether enter debug mode."},
+ )
diff --git a/llava/train/callbacks/autoresume_callback.py b/llava/train/callbacks/autoresume_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..678d06fbed45be8c1e2fb9539281e46841af243d
--- /dev/null
+++ b/llava/train/callbacks/autoresume_callback.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+""" AutoResume callback.
+
+A transformer trainer callback for interfacing with ADLR's AutoResume SDK.
+
+Copyright 2024 NVIDIA CORPORATION.
+"""
+import os
+import sys
+
+import torch
+import transformers
+from transformers.utils import logging
+
+logger = logging.get_logger("transformers")
+
+
+def rank_print(*s):
+ if not torch.distributed.is_initialized():
+ rank = 0
+ else:
+ rank = torch.distributed.get_rank()
+ print(rank, *s)
+
+
+sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
+try:
+ logger.info("Importing AutoResume lib...")
+ from userlib.auto_resume import AutoResume
+
+ AutoResume.init()
+ logger.info("Found AutoResume SDK!")
+except Exception:
+    logger.warning("Did not find AutoResume SDK!")
+    AutoResume = None
+
+
+class AutoResumeCallback(transformers.TrainerCallback):
+ """
+ A [`TrainerCallback`] that handles autoresume.
+
+ Args:
+ interval: interval (in number of iterations) between checks as to
+ whether to suspend.
+ """
+
+ def __init__(self, interval: int = 50):
+ self.interval = interval
+
+ def on_step_end(self, args, state, control, **kwargs):
+ if state.global_step % self.interval == 0:
+ rank_print("AutoResumeHook: Checking whether to suspend...")
+
+ # Check whether to suspend the job.
+ should_preempt = AutoResume is not None and AutoResume.termination_requested()
+
+ if should_preempt:
+ if state.is_local_process_zero:
+ logger.warn(f"AutoResumeHook: Request resume...")
+ if AutoResume is not None:
+ AutoResume.request_resume()
+ control.should_training_stop = True
+ control.should_save = True
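+
+
+# Usage sketch (illustrative): register the callback on a HuggingFace Trainer so the job
+# checkpoints and stops cleanly when the cluster requests termination:
+#
+#   trainer.add_callback(AutoResumeCallback(interval=50))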
diff --git a/llava/train/deepspeed_replace/runtime/zero/mics.py b/llava/train/deepspeed_replace/runtime/zero/mics.py
new file mode 100644
index 0000000000000000000000000000000000000000..95a2ddd01dcc804a69f98f449221d5ff21a810cd
--- /dev/null
+++ b/llava/train/deepspeed_replace/runtime/zero/mics.py
@@ -0,0 +1,551 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+from typing import List
+
+import deepspeed
+import torch
+from deepspeed import comm as dist
+from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.zero.mics_utils import MiCS_CommGroups, create_mics_comm_groups, scale_tensors
+from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload, is_zero_param
+from deepspeed.runtime.zero.partition_parameters import AllGatherCoalescedHandle, Init, ZeroParamStatus
+from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
+from deepspeed.utils import instrument_w_nvtx, log_dist
+from torch import Tensor
+from torch.nn import Parameter
+
+
+def has_hierarchical_all_gather_groups(comm_groups: MiCS_CommGroups):
+ result = False
+ if comm_groups.param_intra_node_group is not None and comm_groups.param_inter_node_shard_group is not None:
+ result = True
+ return result
+
+
+def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
+ return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True)
+
+
+class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle):
+ """This handle assumes that no need to
+ copy data out from a contiguous tensor
+ """
+
+ def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None:
+ super().__init__(allgather_handle, params, partitions, world_size)
+
+ def wait(self) -> None:
+ """ """
+ # let the current stream to op
+ instrument_w_nvtx(self.allgather_handle.wait)()
+ if self.complete:
+ return
+
+ for _, param in enumerate(self.params):
+ assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
+ param.ds_status = ZeroParamStatus.AVAILABLE
+
+ self.complete = True
+
+
+class MiCS_Init(Init):
+ def __init__(
+ self,
+ module=None,
+ data_parallel_group=None,
+ mem_efficient_linear=True,
+ remote_device=None,
+ pin_memory=False,
+ config_dict_or_path=None,
+ config=None,
+ enabled=True,
+ dtype=None,
+ mpu=None,
+ ):
+ """A context manager to partition the model parameters during the model
+ construction with MiCS partition strategy. Model states are partitioned
+ to the number of devices specified via ``mics_shard_size`` field in the
+ deepspeed config json file. The context manager also introduces
+ hierarchical communication method to reduce the cost of inter-node
+ communications, which can be enabled with
+ ``mics_hierarchical_params_gather`` field in deepspeed config.
+
+ Args:
+ module (``torch.nn.Module``, optional): If provided, partition the model as
+ if it was constructed in the context.
+ data_parallel_group (``deepspeed.comm`` process group, optional):
+ The group of processes to partition among. Defaults to all processes.
+ mem_efficient_linear (bool, optional): Replace
+ torch.nn.functional.linear with an implementation that allows
+ DeepSpeed to partition parameters. Defaults to ``True``.
+ remote_device (string, optional): The initial device to store model
+ weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
+ memory. The model may still be moved to GPU based on the
+ offload settings for training. Defaults to param offload device if a config is
+ defined, otherwise GPU.
+ pin_memory (bool, optional): Potentially increase performance by
+ using pinned memory for model weights. ``remote_device`` must be
+ ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``.
+ config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
+ for swapping fp16 params to NVMe.
+ config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
+ enabled (bool, optional): If ``False``, this context has no
+ effect. Defaults to ``True``.
+ dtype (``dtype``, optional): Can be used to change the data type of the parameters.
+ Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
+ mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
+
+ This context follows the same logic as ``deepspeed.zero.Init()``, but
+ with the modification for partition size of each parameter.
+
+ Examples
+ --------
+
+ #. Allocate a model and partition it among all processes:
+
+ .. code-block:: python
+ # the config_dict_or_path is required to let the context manager know
+            # how to partition the parameters.
+ # The configuration has to include the field ``mics_shard_size``
+ with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
+ model = MyLargeModel()
+
+
+ #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:
+
+ .. code-block:: python
+
+            with deepspeed.zero.MiCS_Init(data_parallel_group=mpu.get_data_parallel_group(),
+                                          remote_device="cpu",
+                                          pin_memory=True,
+                                          config_dict_or_path=ds_config):
+ model = MyLargeModel()
+
+
+ #. Partition an already-allocated model in CPU memory:
+
+ .. code-block:: python
+
+ model = deepspeed.zero.MiCS_Init(module=model,
+ config_dict_or_path=ds_config)
+ """
+
+ assert config_dict_or_path is not None, "Must provide configuration for MiCS Initialization"
+ _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu)
+ if not dist.is_initialized():
+ dist.init_distributed()
+ assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
+ self.mics_comm_groups = create_mics_comm_groups(
+ _ds_config.mics_shard_size,
+ data_parallel_group,
+ hierarchical_allgather=_ds_config.mics_hierarchial_params_gather,
+ mpu=mpu,
+ )
+
+ super().__init__(
+ module,
+ data_parallel_group,
+ mem_efficient_linear,
+ remote_device,
+ pin_memory,
+ config_dict_or_path,
+ config,
+ enabled,
+ dtype,
+ mpu,
+ )
+
+ def _convert_to_deepspeed_param(self, param):
+ super()._convert_to_deepspeed_param(param)
+ # attach communication groups to every param
+ param.comm = self.mics_comm_groups
+
+ # record existing all_gather_coalesced implementation
+ # so that we can fallback later
+ old_all_gather_coalesced = param.all_gather_coalesced
+
+ def _param_all_gather_coalesced(params, safe_mode=False, param_buffers=None):
+ """"""
+ mics_comm_groups: MiCS_CommGroups = params[0].comm
+ hierarchical_all_gather = has_hierarchical_all_gather_groups(mics_comm_groups)
+ if dist.has_coalescing_manager() and hierarchical_all_gather:
+ return self._hierarchical_all_gather_params(params, param_buffers)
+ elif dist.has_coalescing_manager():
+ return self._flat_all_gather_with_coalescing_manager(params, param_buffers)
+ else:
+ return old_all_gather_coalesced(params, safe_mode)
+
+ # change the all_gather_coalesced method
+ param.all_gather_coalesced = _param_all_gather_coalesced
+
+ def _pre_all_gather(self, params, params_buffers=None):
+ # fetches from nvme if the partition is not available and in nvme
+ self._ensure_availability_of_partitioned_params(params)
+
+ for param in params:
+ if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
+ raise RuntimeError(param.ds_summary())
+ param.ds_status = ZeroParamStatus.INFLIGHT
+
+ # ensure that each rank has params in same order. the allgather
+ # is done by flattening the parameter list into a single tensor that
+ # can be allgathered in a single call - this means that if each rank
+ # gives a list of the same parameters in a different order we will
+ # silently get incorrect parameter values, and have very difficult
+ # to debug correctness issues.
+ params = sorted(params, key=lambda p: p.ds_id)
+ return params, params_buffers
+
+ def _flat_all_gather_with_coalescing_manager(self, params, params_buffers=None):
+ """"""
+ params, params_buffers = self._pre_all_gather(params, params_buffers)
+ mics_comm_groups: MiCS_CommGroups = params[0].comm
+ param_shard_size = mics_comm_groups.param_shard_size
+ rank_in_group = mics_comm_groups.param_shard_rank
+ partition_sz = sum(p.ds_tensor.ds_numel for p in params)
+
+ # output_tensors = []
+ # input_tensors = []
+ # for i, p in enumerate(params):
+ # t_size = p.ds_tensor.ds_numel * param_shard_size
+ # if params_buffers is not None and params_buffers[i] is not None:
+ # assert params_buffers[i].numel(
+ # ) == t_size, f'params_to_gather_buffers[{i}] size {params_buffers[i].numel()} does not match with t_size {t_size}'
+ # flat_out = params_buffers[i]
+ # else:
+ # flat_out = torch.empty(t_size, dtype=p.dtype, device=self.local_device, requires_grad=False).view(-1)
+ # # flat_out = torch.zeros(t_size, dtype=p.dtype, device=self.local_device).view(-1)
+ # output_tensors.append(flat_out)
+ # _flat_input = p.ds_tensor.data.view(-1)
+ # input_tensors.append(_flat_input)
+
+ # input_tensor = torch.cat(input_tensors, dim=0)
+ # flat_tensor = torch.cat(output_tensors, dim=0)
+ flat_tensor = torch.empty(
+ partition_sz * param_shard_size, dtype=params[0].dtype, device=self.local_device, requires_grad=False
+ ).view(-1)
+
+ partitions: List[Parameter] = []
+ for i in range(param_shard_size):
+ partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
+ instrument_w_nvtx(torch.cat)(
+ [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], out=partitions[rank_in_group]
+ )
+ # Ensure all gather output size is correct
+ assert partitions[rank_in_group].numel() * param_shard_size == flat_tensor.numel()
+ handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, mics_comm_groups.param_shard_group)
+
+ # Clean the buffer after communication
+ # torch.cuda.empty_cache()
+
+ return AllGatherCoalescedHandle(
+ allgather_handle=handle,
+ params=params,
+ partitions=partitions,
+ world_size=param_shard_size,
+ )
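+
+    # Illustrative layout (assumed numbers): with param_shard_size=4 and partition_sz=1000,
+    # flat_tensor holds 4000 elements; shard i owns flat_tensor[1000*i : 1000*(i+1)], this rank
+    # copies its local ds_tensor shards into its own slice, and the all-gather fills in the
+    # remaining three slices from the other shard ranks.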
+
+ def _hierarchical_all_gather_params(self, params, params_buffers=None):
+ """"""
+ raise NotImplementedError("Hierarchical all-gather is not implemented yet")
+ params, params_buffers = self._pre_all_gather(params, params_buffers)
+
+ mics_comm_groups: MiCS_CommGroups = params[0].comm
+ local_rank = dist.get_rank(group=mics_comm_groups.param_intra_node_group)
+ inter_node_comm_group = mics_comm_groups.param_inter_node_shard_group
+ intra_node_comm_group = mics_comm_groups.param_intra_node_group
+ param_shard_size = mics_comm_groups.param_shard_size
+
+ inter_node_size = dist.get_world_size(group=inter_node_comm_group)
+ intra_node_size = dist.get_world_size(group=intra_node_comm_group)
+ param_tensors = []
+ for i, p in enumerate(params):
+ param_size = p.ds_tensor.ds_numel * param_shard_size
+ if params_buffers is not None and params_buffers[i] is not None:
+ assert (
+ params_buffers[i].numel() == param_size
+ ), f"param_buffers[{i}] size {params_buffers[i].numel()} does not match with param_size {param_size}"
+ param_tensor = params_buffers[i]
+ else:
+ param_tensor = torch.empty(
+ param_size, dtype=p.dtype, device=self.local_device, requires_grad=False
+ ).view(-1)
+ param_tensors.append(param_tensor)
+
+ # inter node all-gather
+ inter_outputs = []
+ inter_inputs = []
+ for i, p in enumerate(params):
+ inter_size = p.ds_tensor.ds_numel * inter_node_size
+ _out = param_tensors[i].narrow(0, local_rank * inter_size, inter_size)
+ inter_outputs.append(_out)
+ inter_inputs.append(p.ds_tensor.data.view(-1).to(self.local_device))
+ # sync enqueue
+ dist.all_gather_coalesced(inter_outputs, inter_inputs, group=inter_node_comm_group, async_op=False)
+
+ # intra node all-gather
+ intra_outputs = []
+ intra_inputs = []
+ for i, p in enumerate(params):
+                # partition param into multiple chunks for allgather,
+                # because the inter-node all-gather outputs sit in one contiguous buffer,
+                # while in the parameter memory those inter-node chunks are placed at
+                # different locations.
+ # each chunk is an intra-node output
+ param_chunk = (
+ param_tensors[i].view((inter_node_size, intra_node_size, p.ds_tensor.ds_numel)).narrow(1, local_rank, 1)
+ )
+ param_chunk.copy_(inter_outputs[i].detach().clone().view(param_chunk.size()))
+ output_chunks = torch.chunk(param_tensors[i], inter_node_size)
+ for j, _out in enumerate(output_chunks):
+ intra_chunk_size = intra_node_size * p.ds_tensor.ds_numel
+ local_offset = local_rank * p.ds_tensor.ds_numel
+ _in = param_tensors[i].narrow(0, j * intra_chunk_size + local_offset, p.ds_tensor.ds_numel)
+ intra_outputs.append(_out)
+ intra_inputs.append(_in)
+
+ all_gather_handle = dist.all_gather_coalesced(
+ intra_outputs, intra_inputs, group=intra_node_comm_group, async_op=True
+ )
+ for i, param in enumerate(params):
+ param.data = param_tensors[i].narrow(0, 0, param.ds_numel).view(param.ds_shape).data
+
+ return MiCS_AllGatherCoalescedHandle(
+ allgather_handle=all_gather_handle,
+ params=params,
+ partitions=[],
+ world_size=param_shard_size,
+ )
+
+ def get_partition_dp_group(self, param):
+ return param.comm.param_shard_group
+
+ def get_partition_rank(self):
+ return self.mics_comm_groups.param_shard_rank
+
+ @property
+ def num_partitions(self):
+ return self.mics_comm_groups.param_shard_size
+
+
+class MiCS_Offload(DeepSpeedZeRoOffload):
+ """Wrapper to change the behavior for parameter sharding"""
+
+ def __init__(
+ self,
+ module,
+ timers,
+ ds_config,
+ overlap_comm=True,
+ prefetch_bucket_size=50000000,
+ max_reuse_distance=1000000000,
+ max_live_parameters=1000000000,
+ param_persistence_threshold=100000,
+ model_persistence_threshold=sys.maxsize,
+ offload_param_config=None,
+ mpu=None,
+ ):
+ super().__init__(
+ module,
+ timers,
+ ds_config,
+ overlap_comm,
+ prefetch_bucket_size,
+ max_reuse_distance,
+ max_live_parameters,
+ param_persistence_threshold,
+ model_persistence_threshold,
+ offload_param_config,
+ mpu,
+ )
+
+ def _convert_to_zero_parameters(self, ds_config, module, mpu):
+ """overload the parent class function for convert the parameters"""
+ log_dist(f"Convert to zero parameters from MiCS Offload manager", ranks=[0])
+ non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
+ if non_zero_params:
+ zero_params = [p for p in module.parameters() if is_zero_param(p)]
+ if zero_params:
+ zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
+ else:
+ group = None
+ if mpu:
+ group = mpu.get_data_parallel_group()
+
+ MiCS_Init(
+ module=module,
+ data_parallel_group=group,
+ dtype=self.dtype,
+ config_dict_or_path=ds_config,
+ remote_device=self.offload_device,
+ pin_memory=self.offload_param_pin_memory,
+ mpu=mpu,
+ )
+
+
+class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
+ """
+ MiCS Optimizer
+ """
+
+ def __init__(
+ self,
+ module,
+ init_optimizer,
+ timers,
+ ds_config,
+ static_loss_scale=1,
+ dynamic_loss_scale=False,
+ dynamic_loss_args=None,
+ verbose=True,
+ contiguous_gradients=True,
+ reduce_bucket_size=500000000,
+ prefetch_bucket_size=50000000,
+ max_reuse_distance=1000000000,
+ max_live_parameters=1000000000,
+ param_persistence_threshold=100000,
+ model_persistence_threshold=sys.maxsize,
+ dp_process_group=None,
+ reduce_scatter=True,
+ overlap_comm=False,
+ offload_optimizer_config=None,
+ offload_param_config=None,
+ sub_group_size=1000000000000,
+ mpu=None,
+ clip_grad=0,
+ communication_data_type=torch.float16,
+ postscale_gradients=True,
+ gradient_predivide_factor=1,
+ gradient_accumulation_steps=1,
+ elastic_checkpoint=False,
+ aio_config=None,
+ ):
+
+ log_dist("Init MiCS optimizer", ranks=[0])
+ super().__init__(
+ module,
+ init_optimizer,
+ timers,
+ ds_config,
+ static_loss_scale,
+ dynamic_loss_scale,
+ dynamic_loss_args,
+ verbose,
+ contiguous_gradients,
+ reduce_bucket_size,
+ prefetch_bucket_size,
+ max_reuse_distance,
+ max_live_parameters,
+ param_persistence_threshold,
+ model_persistence_threshold,
+ dp_process_group,
+ reduce_scatter,
+ overlap_comm,
+ offload_optimizer_config,
+ offload_param_config,
+ sub_group_size,
+ mpu,
+ clip_grad,
+ communication_data_type,
+ postscale_gradients,
+ gradient_predivide_factor,
+ gradient_accumulation_steps,
+ elastic_checkpoint,
+ aio_config,
+ )
+ first_param = next(module.parameters())
+ # overload the dp_process_group and partition_count
+ assert hasattr(first_param, "comm"), " ".join(
+ [
+ "Sharded parameters don't have the MiCS_CommGroups attached.",
+ "Might due to the use of deepspeed.zero.Init context for initializing the weights.",
+ "To use MiCS sharding, please use deepspeed.zero.MiCS_Init instead for initializing parameter.",
+ ]
+ )
+ self.dp_process_group = first_param.comm.param_shard_group
+ self.partition_count = first_param.comm.param_shard_size
+
+ def initialize_ds_offload(
+ self,
+ module,
+ timers,
+ ds_config,
+ overlap_comm,
+ prefetch_bucket_size,
+ max_reuse_distance,
+ max_live_parameters,
+ param_persistence_threshold,
+ model_persistence_threshold,
+ offload_param_config,
+ mpu,
+ ):
+ return MiCS_Offload(
+ module,
+ timers,
+ ds_config,
+ overlap_comm,
+ prefetch_bucket_size,
+ max_reuse_distance,
+ max_live_parameters,
+ param_persistence_threshold,
+ model_persistence_threshold,
+ offload_param_config,
+ mpu,
+ )
+
+ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
+ grad_buffers = super().partition_grads(params_to_release, grad_partitions)
+ # perform all-reduce among replication groups
+ # the function will perform accumulation boundary check
+ self.allreduce_mics_shard_grads(params_to_release, grad_buffers)
+
+ @instrument_w_nvtx
+ def allreduce_mics_shard_grads(self, params, partitioned_grads_buffers: List[Tensor]):
+ """ """
+ # TODO: improve the condition check
+ if not self.is_gradient_accumulation_boundary or len(partitioned_grads_buffers) == 0:
+ return
+
+ mics_comm_groups: MiCS_CommGroups = params[0].comm
+ param_repli_group = mics_comm_groups.param_repli_group
+ param_repli_size = mics_comm_groups.param_repli_size
+
+ if param_repli_size is None or param_repli_size <= 1:
+ return
+ if not partitioned_grads_buffers[0].is_cuda:
+ raise RuntimeError("Local sharding has no support for CPU offloading")
+
+ if dist.has_all_reduce_coalesced():
+ scale_tensors(partitioned_grads_buffers, param_repli_size)
+ dist.all_reduce_coalesced(tensors=partitioned_grads_buffers, group=param_repli_group)
+ else:
+ # manually coalescing all-reduce
+ aggregated_buffer: Tensor = torch.cat(partitioned_grads_buffers)
+ aggregated_buffer.div_(param_repli_size)
+ dist.all_reduce(aggregated_buffer, group=param_repli_group)
+ offset = 0
+ for grad_buff in partitioned_grads_buffers:
+ grad_buff.view(-1).copy_(aggregated_buffer.narrow(0, offset, grad_buff.numel()))
+ offset += grad_buff.numel()
+
+ def load_state_dict(
+ self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False, checkpoint_folder=None
+ ):
+ r"""Loading the ZeRO-3/MiCS partitioned checkpoints
+ Because the self.dp_process_group is replaced with the communicator for
+ partition group we can call the load_state_dict logic from ZeRO-3.
+ """
+ super().load_state_dict(state_dict_list, load_optimizer_states, load_from_fp32_weights, checkpoint_folder)
diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bdd7dbc95d5185228ad73e042bd3b287175e6ad
--- /dev/null
+++ b/llava/train/llava_trainer.py
@@ -0,0 +1,1117 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+import json
+import os
+import random
+import time
+from typing import Dict, List, Optional
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler, RandomSampler, Sampler
+from transformers import PreTrainedModel, Trainer
+from transformers.modeling_utils import unwrap_model
+from transformers.trainer import ALL_LAYERNORM_LAYERS # ShardedDDPOption,
+from transformers.trainer import get_parameter_names, has_length, is_sagemaker_mp_enabled, logger
+
+from llava.train.sequence_parallel import get_pg_manager
+from llava.trl.trainer import DPOTrainer
+import numpy as np
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, "no ignore status")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
+ return to_return
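+# Illustrative example: get_mm_adapter_state_maybe_zero_3(model.named_parameters(), ["mm_projector"])
+# returns only the projector weights, gathered from ZeRO-3 shards if necessary and moved to CPU.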
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+    Split a list of indices into `num_chunks` chunks of roughly equal total length.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float("inf")
+
+ return chunks
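+# Illustrative example: indices [0, 1, 2, 3] with lengths [10, 9, 2, 1] and num_chunks=2
+# yield chunks [[0, 3], [1, 2]]: each index goes to the currently shortest chunk, and a chunk
+# is closed (its length set to inf) once it holds len(indices) // num_chunks items.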
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+ # all samples are in the same modality
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [
+ lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)
+ ]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(additional_batch) > 0:
+ megabatches.append(sorted(additional_batch))
+
+ return [i for megabatch in megabatches for i in megabatch]
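+# Illustrative example: lengths [30, -12, 25, -7] mark samples 0 and 2 as multimodal (lengths
+# 30 and 25) and samples 1 and 3 as language-only (lengths 12 and 7); each modality is
+# length-grouped separately, and only the leftover partial megabatches are mixed at the end.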
+
+
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
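+# Illustrative example: with batch_size=2 and world_size=2, the megabatch size is 4; indices are
+# shuffled, cut into megabatches of 4, sorted longest-first within each megabatch, and each
+# megabatch is then split into 2 length-balanced chunks, one per rank.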
+
+
+class VILADistributedSampler(DistributedSampler):
+ """This class is implemented by Jason Lu."""
+
+ def __init__(
+ self,
+ dataset,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ batch_size=None,
+ # NOTE: this is the total size but not per-worker
+ sample_len_list=None,
+ force_accumulation=True,
+ sp_degree: int = 1,
+ gradient_accumulation_steps: int = 1,
+ ) -> None:
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ if rank >= num_replicas or rank < 0:
+ raise ValueError(
+ "Invalid rank {}, rank should be in the interval" " [0, {}]".format(rank, num_replicas - 1)
+ )
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.drop_last = True # always True
+ self.sp_degree = max(1, sp_degree)
+ self.bs_divisible_by_sp = batch_size % self.sp_degree == 0
+
+ # Consider sequence parallelism
+ if self.sp_degree > 1: # Sequence Parallelism is enabled
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ self.dp_rank = PROCESS_GROUP_MANAGER.dp_rank
+ self.dp_num_replicas = num_replicas // sp_degree
+ self.corresponding_ranks = list(range(self.dp_rank * self.sp_degree, (self.dp_rank + 1) * self.sp_degree))
+ else:
+ self.dp_rank = rank
+ self.dp_num_replicas = num_replicas
+
+ self.batch_size = batch_size
+ self.global_batch_size = batch_size * self.dp_num_replicas
+
+ # NOTE: org_ is without drop last
+ self.org_sample_len_list = self.per_replica_samples = sample_len_list
+ assert sum(sample_len_list) == len(self.dataset)
+
+ if self.drop_last: # type: ignore[arg-type]
+ self.per_replica_samples = [
+ sample_len
+ // (self.num_replicas * self.batch_size * gradient_accumulation_steps // self.sp_degree)
+ * self.batch_size
+ * gradient_accumulation_steps
+ // self.sp_degree
+ for sample_len in self.per_replica_samples
+ ]
+ self.num_samples = sum(self.per_replica_samples)
+ else:
+ raise NotImplementedError
+
+ self.total_size = self.num_samples * self.num_replicas
+ self.total_samples = [samples * self.num_replicas for samples in self.per_replica_samples]
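+ # Worked example (hypothetical numbers): with sample_len_list=[1000, 500],
+ # num_replicas=8, batch_size=4, gradient_accumulation_steps=1 and sp_degree=1,
+ # per_replica_samples = [1000 // 32 * 4, 500 // 32 * 4] = [124, 60], so
+ # num_samples = 184 and total_size = 184 * 8 = 1472.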
+
+ self.shuffle = shuffle
+ self.seed = seed
+
+ # whether to force accumulate
+ self.force_accumulation = force_accumulation
+
+ def __len__(self) -> int:
+ return self.num_samples * self.sp_degree
+
+ def __iter__(self):
+
+ indices = list(range(len(self.dataset)))
+
+ # 1. split the full indices first (note: without drop last at this moment)
+ indices_list = []
+ for i in range(len(self.org_sample_len_list)):
+ indices_list.append(
+ indices[sum(self.org_sample_len_list[:i]) : sum(self.org_sample_len_list[:i]) + self.total_samples[i]]
+ )
+
+ assert sum([len(indices) for indices in indices_list]) == self.total_size, (
+ sum([len(indices) for indices in indices_list]),
+ self.total_size,
+ )
+
+ if (
+ self.sp_degree > 1 and self.bs_divisible_by_sp
+ ): # Sequence Parallelism is enabled, to ensure the same behavior as data parallelism
+ dp_indices_dict = {} # {rank: indices_list}
+ all_indices_dict = {} # {rank: all_indices}
+
+ for i in self.corresponding_ranks:
+ dp_indices_list = []
+ for idx, indices in enumerate(indices_list):
+ dp_indices_list.append(
+ indices[i * self.per_replica_samples[idx] : (i + 1) * self.per_replica_samples[idx]]
+ )
+
+ random.seed(self.seed + self.epoch)
+ for indice in range(len(dp_indices_list)):
+ random.shuffle(dp_indices_list[indice])
+
+ dp_indices_dict[i] = dp_indices_list.copy()
+
+ for rank, dp_indices_list in dp_indices_dict.items():
+ dp_indices_list = sorted(dp_indices_list, key=lambda x: -len(x))
+ dp_all_indices = [-1] * self.num_samples
+ indices_available = list(range(self.num_samples))
+
+ for indice in dp_indices_list:
+
+ original_indices = range(len(indice))
+ transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
+
+ mapped_indices = [indices_available[idx] for idx in transformed_indices]
+ # update indices_available
+ for idx in reversed(transformed_indices):
+ del indices_available[idx]
+ for i, idx in enumerate(mapped_indices):
+ dp_all_indices[idx] = indice[i]
+
+ all_indices_dict[rank] = dp_all_indices
+
+ # Interleaving Merge
+ merged_indices = []
+ interleaved_indices = []
+ for item_idx in range(len(all_indices_dict[self.corresponding_ranks[0]])):
+ for rank in self.corresponding_ranks:
+ interleaved_indices.append(all_indices_dict[rank][item_idx])
+ merged_indices.append(interleaved_indices)
+
+ all_indices = merged_indices[0]
+ else:
+ # let's first do subsample
+ for idx, indices in enumerate(indices_list):
+ indices_list[idx] = indices[
+ self.rank * self.per_replica_samples[idx] : (self.rank + 1) * self.per_replica_samples[idx]
+ ]
+
+ random.seed(self.seed + self.epoch)
+ for indice in range(len(indices_list)):
+ random.shuffle(indices_list[indice])
+
+ indices_list = sorted(indices_list, key=lambda x: -len(x))
+ all_indices = [-1] * self.num_samples
+ indices_available = list(range(self.num_samples))
+
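+ # Spread each per-dataset list evenly across the available slots: a list of k indices
+ # mapped into m remaining slots lands at slot positions i * m // k. For example
+ # (hypothetical), 3 indices into 12 slots occupy relative positions 0, 4 and 8.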
+ for indice in indices_list:
+
+ original_indices = range(len(indice))
+ transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
+
+ mapped_indices = [indices_available[idx] for idx in transformed_indices]
+ # update indices_available
+ for idx in reversed(transformed_indices):
+ del indices_available[idx]
+ for i, idx in enumerate(mapped_indices):
+ all_indices[idx] = indice[i]
+ assert -1 not in all_indices
+ return iter(all_indices)
+
+
+class LongVILADistributedSampler(VILADistributedSampler):
+ """This class is implemented by Yukang Chen."""
+
+ def __iter__(self):
+ def batch_shuffle(indices):
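+ # Shuffle whole batches while keeping the within-batch order, so contiguous chunks
+ # stay together. Example (hypothetical): with batch_size=2 and indices [4, 5, 6, 7],
+ # the batch ids are [2, 3]; if the shuffle yields [3, 2], the result is [6, 7, 4, 5].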
+ batch_indices = list(range(indices[0] // self.batch_size, indices[-1] // self.batch_size + 1))
+ random.shuffle(batch_indices)
+ indices_shuffled = [
+ batch_indices[i // self.batch_size] * self.batch_size + index % self.batch_size
+ for i, index in enumerate(indices)
+ ]
+ return indices_shuffled
+
+ indices = list(range(len(self.dataset)))
+
+ # 1. split the full indices first (note: without drop last at this moment)
+ indices_list = []
+ for i in range(len(self.org_sample_len_list)):
+ indices_list.append(
+ indices[sum(self.org_sample_len_list[:i]) : sum(self.org_sample_len_list[:i]) + self.total_samples[i]]
+ )
+
+ assert sum([len(indices) for indices in indices_list]) == self.total_size, (
+ sum([len(indices) for indices in indices_list]),
+ self.total_size,
+ )
+
+ if self.sp_degree > 1: # Sequence Parallelism is enabled, to ensure the same behavior as data parallelism
+ dp_indices_dict = {} # {rank: indices_list}
+ all_indices_dict = {} # {rank: all_indices}
+
+ for i in self.corresponding_ranks:
+ dp_indices_list = []
+ for idx, indices in enumerate(indices_list):
+ dp_indices_list.append(
+ indices[i * self.per_replica_samples[idx] : (i + 1) * self.per_replica_samples[idx]]
+ )
+
+ random.seed(self.seed + self.epoch)
+ for indice in range(len(dp_indices_list)):
+ dp_indices_list[indice] = batch_shuffle(dp_indices_list[indice])
+
+ dp_indices_dict[i] = dp_indices_list.copy()
+
+ for rank, dp_indices_list in dp_indices_dict.items():
+ dp_indices_list = sorted(dp_indices_list, key=lambda x: -len(x))
+ dp_all_indices = [-1] * self.num_samples
+ indices_available = list(range(self.num_samples))
+
+ for indice in dp_indices_list:
+
+ original_indices = range(len(indice))
+ transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
+
+ mapped_indices = [indices_available[idx] for idx in transformed_indices]
+ # update indices_available
+ for idx in reversed(transformed_indices):
+ del indices_available[idx]
+ for i, idx in enumerate(mapped_indices):
+ dp_all_indices[idx] = indice[i]
+
+ all_indices_dict[rank] = dp_all_indices
+
+ # Interleaving Merge
+ merged_indices = []
+ interleaved_indices = []
+ for item_idx in range(len(all_indices_dict[self.corresponding_ranks[0]])):
+ for rank in self.corresponding_ranks:
+ interleaved_indices.append(all_indices_dict[rank][item_idx])
+ merged_indices.append(interleaved_indices)
+
+ all_indices = merged_indices[0]
+ else:
+ # let's first do subsample
+ for idx, indices in enumerate(indices_list):
+ indices_list[idx] = indices[
+ self.rank * self.per_replica_samples[idx] : (self.rank + 1) * self.per_replica_samples[idx]
+ ]
+
+ random.seed(self.seed + self.epoch)
+ for indice in range(len(indices_list)):
+ indices_list[indice] = batch_shuffle(indices_list[indice])
+
+ indices_list = sorted(indices_list, key=lambda x: -len(x))
+ all_indices = [-1] * self.num_samples
+ indices_available = list(range(self.num_samples))
+ for indice in indices_list:
+ original_indices = range(len(indice))
+ transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
+ mapped_indices = [indices_available[idx] for idx in transformed_indices]
+ # update indices_available
+ for idx in reversed(transformed_indices):
+ del indices_available[idx]
+ for i, idx in enumerate(mapped_indices):
+ all_indices[idx] = indice[i]
+ assert -1 not in all_indices
+ return iter(all_indices)
+
+def get_length_grouped_batches(
+ lengths: List[int],
+ batch_size: int,
+ world_size: int,
+ generator=None,
+ merge: bool = True,
+) -> List:
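+ """
+ Vectorized variant of length grouping: indices are sorted globally by descending
+ length (random tie-break), trimmed and cut into megabatches of
+ world_size * batch_size, shuffled at the megabatch level, and split into one
+ contiguous slice of `batch_size` per replica.
+
+ Example (hypothetical): lengths=[9, 3, 7, 5, 8, 2, 6, 4], batch_size=2,
+ world_size=2 and merge=False gives two megabatches such as
+ [[[0, 4], [2, 6]], [[3, 7], [1, 5]]] (the megabatch order is random).
+ """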
+
+ N = len(lengths)
+ M = world_size * batch_size
+ if N < M:
+ # fallback: just randomly permute everything
+ idx = np.arange(N)
+ if generator is not None:
+ seed = generator.initial_seed()
+ rng = np.random.RandomState(seed)
+ else:
+ rng = np.random.RandomState()
+ rng.shuffle(idx)
+ if merge:
+ return idx.tolist()
+ else:
+ # one megabatch only
+ out = [idx.tolist()]
+ # pad to world_size empty lists if needed
+ return [out + [[]] * (world_size - 1)]
+
+ # 1) build RNG
+ if generator is not None:
+ seed = generator.initial_seed()
+ rng = np.random.RandomState(seed)
+ else:
+ rng = np.random.RandomState()
+
+ # 2) keys for lexsort: primary = -length, secondary = random
+ lengths_arr = np.array(lengths, dtype=np.int64)
+ key_length = -lengths_arr
+ key_rand = rng.permutation(N)
+
+ # 3) single global lexsort (last key is primary)
+ sorted_idx = np.lexsort((key_rand, key_length))
+
+ # 4) trim to full megabatches
+ num_mb = len(sorted_idx) // M
+ trimmed = sorted_idx[: num_mb * M]
+
+ # 5) reshape to [num_mb, M]
+ mb = trimmed.reshape(num_mb, M)
+
+ # 6) optional shuffle of whole megabatches
+ rng.shuffle(mb)
+
+ # 7) split each row into [world_size, batch_size]
+ mb = mb.reshape(num_mb, world_size, batch_size)
+
+ if merge:
+ # flatten in order megabatch → replica → sample
+ return mb.reshape(-1).tolist()
+ else:
+ # build nested Python lists: [ [ [..], [..], … ], … ]
+ return [
+ [mb[i, r].tolist() for r in range(world_size)]
+ for i in range(num_mb)
+ ]
+
+
+# def get_length_grouped_batches(
+# lengths: List[int],
+# batch_size: int,
+# world_size: int,
+# generator=None,
+# merge: bool = True,
+# ) -> List:
+# """
+# Create length-grouped megabatches.
+
+# First, a random permutation of indices is computed. Then we split
+# into megabatches of size (world_size * batch_size) and sort each
+# megabatch by descending length. Finally, each megabatch is split
+# into `world_size` chunks (one per replica).
+
+# If merge is True, a flat list is returned; if False, the nested
+# structure is kept.
+# """
+# indices = torch.randperm(len(lengths), generator=generator)
+# megabatch_size = world_size * batch_size
+# # Partition indices into megabatches
+# megabatches = [
+# indices[i : i + megabatch_size].tolist()
+# for i in range(0, len(lengths), megabatch_size)
+# ]
+# # Within each megabatch, sort indices in descending order of length.
+# sorted_megabatches = [
+# sorted(megabatch, key=lambda i: lengths[i], reverse=True)
+# for megabatch in megabatches
+# ]
+# # Split each sorted megabatch evenly among replicas.
+# split_megabatches = [
+# split_to_even_chunks(megabatch, lengths, world_size)
+# for megabatch in sorted_megabatches
+# ]
+# if merge:
+# # Flatten into a single list.
+# return [i for megabatch in split_megabatches for batch in megabatch for i in batch]
+# else:
+# # Return the nested structure: list of megabatches, each containing a list (of length world_size) of batches.
+# return split_megabatches
+
+class LengthGroupedVILADistributedSampler(DistributedSampler):
+ """
+ A sampler that groups examples by (approximate) length and then
+ distributes them across replicas following VILA's accumulation logic.
+
+ Parameters:
+ - dataset: the dataset to sample from.
+ - batch_size: batch size per replica.
+ - lengths: a list of lengths (one per example in the dataset).
+ - num_replicas: total number of distributed replicas (if not provided,
+ will be inferred from torch.distributed).
+ - rank: the rank of the current process.
+ - shuffle: whether to shuffle groups.
+ - seed: base random seed.
+ - drop_last: whether to drop the tail of incomplete megabatches (set True).
+ - sp_degree: sequence-parallel degree.
+ - gradient_accumulation_steps: used for scaling the effective batch size.
+ - group_by_modality: if True, you might call a different grouping function.
+ - generator: optional torch.Generator for determinism.
+ - force_accumulation: whether to force the VILA accumulation ordering.
+ """
+ def __init__(
+ self,
+ dataset,
+ batch_size: int,
+ lengths: List[int],
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = True,
+ sp_degree: int = 1,
+ gradient_accumulation_steps: int = 1,
+ group_by_modality: bool = True,
+ generator=None,
+ force_accumulation: bool = True,
+ ):
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank,
+ shuffle=shuffle, seed=seed, drop_last=drop_last)
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.lengths = lengths
+ self.generator = generator
+ self.group_by_modality = group_by_modality
+ self.sp_degree = max(1, sp_degree)
+ self.gradient_accumulation_steps = gradient_accumulation_steps
+ self.force_accumulation = force_accumulation
+ self.seed = seed
+ self.epoch = 0 # This should be updated externally at each epoch.
+
+ self.world_size = self.num_replicas # from DistributedSampler
+
+ self.bs_divisible_by_sp = (batch_size % self.sp_degree == 0)
+ if self.sp_degree > 1:
+ # Get sequence parallelism group info.
+ PROCESS_GROUP_MANAGER = get_pg_manager() # Must be implemented.
+ self.dp_rank = PROCESS_GROUP_MANAGER.dp_rank
+ self.dp_num_replicas = self.num_replicas // self.sp_degree
+ self.corresponding_ranks = list(range(self.dp_rank * self.sp_degree, (self.dp_rank + 1) * self.sp_degree))
+ else:
+ self.dp_rank = self.rank
+ self.dp_num_replicas = self.num_replicas
+
+ # Compute the number of full megabatches (each of size world_size * batch_size).
+ megabatch_size = self.world_size * self.batch_size
+ num_full_megabatches = len(self.dataset) // megabatch_size
+ # For each full megabatch, each replica gets batch_size examples.
+ self.num_samples = num_full_megabatches * self.batch_size
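+ # Example (hypothetical): with len(dataset)=35, world_size=4 and batch_size=2,
+ # megabatch_size=8, there are 4 full megabatches and num_samples = 4 * 2 = 8.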
+
+ def __len__(self) -> int:
+ # When using sequence parallelism, the effective number may be scaled.
+ return self.num_samples * (self.sp_degree if self.sp_degree > 1 else 1)
+
+ def __iter__(self):
+ # Get the nested list of length-grouped batches.
+ # Each element in "megabatches" is a list of length world_size, one per replica.
+ megabatches = get_length_grouped_batches(
+ self.lengths,
+ self.batch_size,
+ self.world_size,
+ generator=self.generator,
+ merge=False,
+ )
+ # For each megabatch, select the batch corresponding to this replica.
+ indices_list = []
+ for megabatch in megabatches:
+ if self.rank < len(megabatch):
+ indices_list.append(megabatch[self.rank])
+ total_samples = sum(len(lst) for lst in indices_list)
+
+ if self.sp_degree > 1 and self.bs_divisible_by_sp:
+ # --- Sequence Parallelism branch ---
+ # For each of the corresponding sequence-parallel ranks, split each batch.
+ dp_indices_dict = {}
+ all_indices_dict = {}
+ for r in self.corresponding_ranks:
+ dp_indices_list = []
+ for lst in indices_list:
+ # Split each list into sp_degree equal parts.
+ part_size = len(lst) // self.sp_degree
+ dp_indices_list.append(lst[r * part_size : (r + 1) * part_size])
+ random.seed(self.seed + self.epoch)
+ for sublist in dp_indices_list:
+ random.shuffle(sublist)
+ dp_indices_dict[r] = dp_indices_list.copy()
+ # Now, for each sequence-parallel rank, remap the indices.
+ for r, dp_list in dp_indices_dict.items():
+ # Sort the sublists by descending length.
+ dp_list = sorted(dp_list, key=lambda x: -len(x))
+ num_samples_r = sum(len(x) for x in dp_list)
+ dp_all_indices = [-1] * num_samples_r
+ indices_available = list(range(num_samples_r))
+ for sublist in dp_list:
+ n = len(sublist)
+ transformed_indices = [i * len(indices_available) // n for i in range(n)]
+ mapped_indices = [indices_available[j] for j in transformed_indices]
+ for j in sorted(transformed_indices, reverse=True):
+ del indices_available[j]
+ for i, pos in enumerate(mapped_indices):
+ dp_all_indices[pos] = sublist[i]
+ all_indices_dict[r] = dp_all_indices
+ # Interleave the indices from all sequence-parallel ranks.
+ merged_indices = []
+ # Assumes each dp_all_indices list is of the same length.
+ interleaved_length = len(next(iter(all_indices_dict.values())))
+ for i in range(interleaved_length):
+ for r in self.corresponding_ranks:
+ merged_indices.append(all_indices_dict[r][i])
+ final_indices = merged_indices
+ else:
+ # --- Non-sequence-parallel branch ---
+ random.seed(self.seed + self.epoch)
+ for sublist in indices_list:
+ random.shuffle(sublist)
+ # Sort the groups by descending length.
+ indices_list = sorted(indices_list, key=lambda x: -len(x))
+ dp_all_indices = [-1] * total_samples
+ indices_available = list(range(total_samples))
+ for sublist in indices_list:
+ n = len(sublist)
+ transformed_indices = [i * len(indices_available) // n for i in range(n)]
+ mapped_indices = [indices_available[j] for j in transformed_indices]
+ for j in sorted(transformed_indices, reverse=True):
+ del indices_available[j]
+ for i, pos in enumerate(mapped_indices):
+ dp_all_indices[pos] = sublist[i]
+ final_indices = dp_all_indices
+
+ assert -1 not in final_indices, "Some indices were not assigned properly."
+ return iter(final_indices)
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ lengths: Optional[List[int]] = None,
+ generator=None,
+ group_by_modality: bool = False,
+ ):
+ if lengths is None:
+ raise ValueError("Lengths must be provided.")
+
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+ self.group_by_modality = group_by_modality
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ if self.group_by_modality:
+ indices = get_modality_length_grouped_indices(
+ self.lengths, self.batch_size, self.world_size, generator=self.generator
+ )
+ else:
+ indices = get_length_grouped_indices(
+ self.lengths, self.batch_size, self.world_size, generator=self.generator
+ )
+ return iter(indices)
+
+
+class VILADPOTrainer(DPOTrainer):
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ # Always using Jason's sampler.
+ sample_len_list = self.args.sample_lens
+ seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
+ num_replicas = self.args.world_size
+ rank = self.args.process_index
+ return VILADistributedSampler(
+ self.train_dataset,
+ num_replicas=num_replicas,
+ rank=rank,
+ seed=seed,
+ batch_size=self.args.train_batch_size,
+ sample_len_list=sample_len_list,
+ sp_degree=self.args.seq_parallel_size,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ )
+
+ def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+ if self.eval_dataset is None or not has_length(self.eval_dataset):
+ return None
+
+ # Always using Jason's sampler.
+ sample_len_list = self.args.eval_sample_lens
+ seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
+ return VILADistributedSampler(
+ eval_dataset,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=seed,
+ batch_size=self.args.eval_batch_size,
+ sample_len_list=sample_len_list,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ )
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ if is_sagemaker_mp_enabled():
+ return super().create_optimizer()
+ # if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ # return super().create_optimizer()
+
+ opt_model = self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ if self.args.mm_projector_lr is not None:
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
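+ # Four parameter groups: non-projector with weight decay, non-projector without decay,
+ # projector with decay at mm_projector_lr, and projector without decay at mm_projector_lr.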
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ "lr": self.args.mm_projector_lr,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ "lr": self.args.mm_projector_lr,
+ },
+ ]
+ else:
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ if 0: # self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer = OSS(
+ params=optimizer_grouped_parameters,
+ optim=optimizer_cls,
+ **optimizer_kwargs,
+ )
+ else:
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped/2**20}M params")
+
+ return self.optimizer
+
+ def save_model(self, output_dir: Optional[str], _internal_call: bool):
+ ## save tuned model separately
+ if self.is_deepspeed_enabled:
+ state_dict = self.accelerator.get_state_dict(self.deepspeed)
+ else:
+ # TODO(ligeng): fix save_model for multi-node training on large models (e.g., Llama-70b)
+ state_dict = self.model.state_dict()
+
+ if self.args.should_save:
+ return self.model.save_pretrained(output_dir, state_dict=state_dict)
+
+
+class LLaVATrainer(Trainer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.model_accepts_loss_kwargs = True
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ print('AF3 sampler')
+ sample_len_list = self.args.sample_lens
+ seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
+ num_replicas = self.args.world_size
+ rank = self.args.process_index
+ longvila_sampler = self.args.longvila_sampler
+
+ if self.args.group_by_modality_length:
+ sampler = LengthGroupedVILADistributedSampler
+ if not isinstance(self.train_dataset, ConcatDataset):
+ lengths = self.train_dataset.modality_lengths
+ else:
+ lengths = []
+ for d in self.train_dataset.datasets:
+ lengths += d.modality_lengths
+
+ return sampler(
+ self.train_dataset,
+ lengths=lengths,
+ num_replicas=num_replicas,
+ rank=rank,
+ seed=seed,
+ batch_size=self.args.train_batch_size,
+ sp_degree=self.args.seq_parallel_size,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ group_by_modality=True
+ )
+ else:
+ sampler = LongVILADistributedSampler if longvila_sampler else VILADistributedSampler
+ return sampler(
+ self.train_dataset,
+ num_replicas=num_replicas,
+ rank=rank,
+ seed=seed,
+ batch_size=self.args.train_batch_size,
+ sample_len_list=sample_len_list,
+ sp_degree=self.args.seq_parallel_size,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ )
+
+
+ def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+ if self.eval_dataset is None or not has_length(self.eval_dataset):
+ return None
+
+ sample_len_list = self.args.eval_sample_lens
+ seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
+ return VILADistributedSampler(
+ eval_dataset,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=seed,
+ batch_size=self.args.eval_batch_size,
+ sample_len_list=sample_len_list,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ )
+
+ def _inner_training_loop(self, batch_size: Optional[int] = None, *args, **kwargs):
+ # NOTE(zhijianl): In the latest transformers, if the batch size in the training arguments differs from
+ # the one in the training state, the batch size from the state is used by default. This can be
+ # problematic when resuming with different batch sizes or gradient accumulation steps. To prevent this,
+ # we enforce using the batch size specified in the training arguments.
+ batch_size = self.args.train_batch_size
+ return super()._inner_training_loop(batch_size, *args, **kwargs)
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ if is_sagemaker_mp_enabled():
+ return super().create_optimizer()
+ # if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ # return super().create_optimizer()
+
+ opt_model = self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ if self.args.mm_projector_lr is not None:
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ "lr": self.args.mm_projector_lr,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ "lr": self.args.mm_projector_lr,
+ },
+ ]
+ elif self.args.vision_tower_lr is not None:
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "vision_tower" in name]
+ # projector_lora_A_parameters = [name for name in projector_parameters if "lora_A" in name]
+ # projector_lora_B_parameters = [name for name in projector_parameters if "lora_B" in name]
+ # other_lora_A_parameters = [name for name in opt_model.named_parameters() if "lora_A" in name and name not in projector_parameters]
+ # other_lora_B_parameters = [name for name in opt_model.named_parameters() if "lora_B" in name and name not in projector_parameters]
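+ # Same four-way grouping as the mm_projector branch above, but the separate learning
+ # rate is applied to the vision tower parameters instead.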
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ "lr": self.args.vision_tower_lr,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ "lr": self.args.vision_tower_lr,
+ },
+ ]
+ else:
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ if 0: # self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer = OSS(
+ params=optimizer_grouped_parameters,
+ optim=optimizer_cls,
+ **optimizer_kwargs,
+ )
+ else:
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped/2**20}M params")
+
+ return self.optimizer
+
+ def save_model(self, output_dir: Optional[str], _internal_call: bool):
+ ## save tuned model separately
+ if self.is_deepspeed_enabled:
+ state_dict = self.accelerator.get_state_dict(self.deepspeed)
+ else:
+ # TODO(ligeng): fix save_model for multi-node training on large models (e.g., Llama-70b)
+ state_dict = self.model.state_dict()
+
+ if self.args.lora_enable:
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters())
+ os.makedirs(output_dir, exist_ok=True)
+ torch.save(
+ non_lora_state_dict,
+ os.path.join(output_dir, "non_lora_trainables.bin"),
+ )
+ # config
+ self.model._name_or_path = output_dir
+ self.model.architectures = [self.model.__class__.__name__]
+ self.model.config.save_pretrained(output_dir)
+
+ if self.args.should_save:
+ return self.model.save_pretrained(output_dir, state_dict=state_dict)
+
+ def log(self, logs: Dict[str, float]) -> None:
+ """
+ Log `logs` on the various objects watching training.
+
+ Subclass and override this method to inject custom behavior.
+
+ Args:
+ logs (`Dict[str, float]`):
+ The values to log.
+ """
+ if self.state.epoch is not None:
+ logs["epoch"] = round(self.state.epoch, 2)
+ if self.args.include_num_input_tokens_seen:
+ logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
+
+ output = {**logs, **{"step": self.state.global_step}}
+ self.state.log_history.append(output)
+
+ if self.args.debug_e2e and self.control.should_training_stop:
+
+ # Only save log history if the current process is rank 0
+ if dist.get_rank() == 0:
+ with open(f"{self.args.output_dir}/log_history.json", "w") as f:
+ json.dump(self.state.log_history, f, indent=4)
+
+ self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
diff --git a/llava/train/sequence_parallel/__init__.py b/llava/train/sequence_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f67c6eb59a4889f6ee2380370bee77d6e8e24628
--- /dev/null
+++ b/llava/train/sequence_parallel/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .globals import get_pg_manager, set_pg_manager
+from .input_utils import extract_local_from_list, extract_local_input_ids, extract_local_position_ids
diff --git a/llava/train/sequence_parallel/all_to_all.py b/llava/train/sequence_parallel/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..887a03afa4fce6b837961f26c0c95ffd04c628d3
--- /dev/null
+++ b/llava/train/sequence_parallel/all_to_all.py
@@ -0,0 +1,292 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This file is modified from https://github.com/feifeibear/long-context-attention
+# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
+
+
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.nn import Module
+
+from llava.train.sequence_parallel.globals import (
+ get_ulysses_seq_len,
+ get_ulysses_sp_pg,
+ get_ulysses_sp_rank,
+ get_ulysses_sp_size,
+ set_ulysses_seq_len,
+)
+
+
+def all_to_all_4D(input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.tensor:
+ """
+ all-to-all for QKV
+
+ Args:
+ input (torch.tensor): a tensor sharded along dim scatter dim
+ scatter_idx (int): default 2
+ gather_idx (int): default 1
+ group : torch process group
+
+ Returns:
+ torch.tensor: resharded tensor, (bs, seqlen, hc/P, hs) for the default (scatter_idx=2, gather_idx=1) direction
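+
+ Example (hypothetical shapes): with 4 Ulysses ranks, bs=1, hc=32 heads, hs=128 and a
+ local shard of 512 tokens, the first call (scatter_idx=2, gather_idx=1) reshapes
+ (1, 512, 32, 128) into (1, 2048, 8, 128); the second call (scatter_idx=1,
+ gather_idx=2) reverses the mapping.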
+ """
+ assert input.dim() == 4, f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
+
+ # seq_world_size = dist.get_world_size(group)
+ # (DL): Change to ulysses size to handle hybrid parallelism.
+ seq_world_size = get_ulysses_sp_size()
+ if scatter_idx == 2 and gather_idx == 1:
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
+ bs, shard_seqlen, hc, hs = input.shape
+ # (Dacheng): In the multi-modality case the shard sequence length differs across ranks,
+ # which all_to_all_single cannot handle directly, so gather every rank's shard length
+ # and pad to the longest before the exchange.
+ # (Dacheng): This gather runs for every attention call so that the second a2a sees the
+ # correct lengths. (TODO) Maybe optimize it to once per forward call.
+ ulysses_seq_len = [torch.zeros(1, dtype=torch.int64, device=input.device) for _ in range(get_ulysses_sp_size())]
+ dist.barrier(group=get_ulysses_sp_pg())
+ dist.all_gather(ulysses_seq_len, torch.tensor(shard_seqlen, device=input.device), group=get_ulysses_sp_pg())
+ set_ulysses_seq_len(ulysses_seq_len)
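+ # e.g. (hypothetical) shards of 120 and 118 tokens are both padded to 120 below and the
+ # padding is stripped again after the exchange.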
+
+ max_global_length = max(ulysses_seq_len)
+ # pad to the second dimension to the longest
+ input = torch.nn.functional.pad(input, (0, 0, 0, 0, 0, max_global_length - shard_seqlen))
+
+ seqlen = max_global_length * seq_world_size # shard_seqlen * seq_world_size
+ shard_hc = hc // seq_world_size
+
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
+ # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
+ input_t = (
+ # input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs)
+ input.reshape(bs, max_global_length, seq_world_size, shard_hc, hs)
+ .transpose(0, 2)
+ .contiguous()
+ )
+
+ output = torch.empty_like(input_t)
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
+ dist.barrier(group=group)
+ dist.all_to_all_single(output, input_t, group=group)
+
+ # if scattering the seq-dim, transpose the heads back to the original dimension
+ output = output.reshape(seqlen, bs, shard_hc, hs)
+
+ # then we will unpad it back
+ output_list = torch.split(output, max_global_length, dim=0)
+ assert len(output_list) == get_ulysses_sp_size()
+ unpadded_output_list = [_output[: _seqlen.item()] for _output, _seqlen in zip(output_list, ulysses_seq_len)]
+
+ # Concatenate the unpadded tensors back together
+ output = torch.cat(unpadded_output_list)
+
+ # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
+ output = output.transpose(0, 1).contiguous().reshape(bs, sum(ulysses_seq_len), shard_hc, hs)
+
+
+ return output
+
+ elif scatter_idx == 1 and gather_idx == 2:
+ ulysses_seq_len = get_ulysses_seq_len()
+ assert ulysses_seq_len is not None, "the second a2a (scatter 1, gather 2) was called before the first one."
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
+ bs, _, shard_hc, hs = input.shape
+ hc = shard_hc * seq_world_size
+
+ # First we need to recover how to pad
+ max_global_length = max(ulysses_seq_len)
+
+ unpadded_input_list = torch.split(input, ulysses_seq_len, dim=1)
+ padded_input_list = [
+ torch.nn.functional.pad(_unpadded_input, (0, 0, 0, 0, 0, max_global_length - _unpadded_input.shape[1]))
+ for _unpadded_input in unpadded_input_list
+ ]
+ input = torch.cat(padded_input_list, dim=1)
+
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
+ # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
+ input_t = (
+ input.reshape(bs, seq_world_size, max_global_length, shard_hc, hs)
+ .transpose(0, 3)
+ .transpose(0, 1)
+ .contiguous()
+ .reshape(seq_world_size, shard_hc, max_global_length, bs, hs)
+ )
+
+ output = torch.empty_like(input_t)
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
+ # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
+ dist.barrier(group=group)
+ dist.all_to_all_single(output, input_t, group=group)
+
+ # if scattering the seq-dim, transpose the heads back to the original dimension
+ output = output.reshape(hc, max_global_length, bs, hs)
+
+ # unpad the output
+ self_length = ulysses_seq_len[get_ulysses_sp_rank()]
+ # print(f"Self length {self_length}")
+ output = output[:, :self_length, :, :]
+
+ # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs)
+ output = output.transpose(0, 2).contiguous().reshape(bs, self_length, hc, hs)
+ return output
+ else:
+ raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
+
+
+class SeqAllToAll4D(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ group: dist.ProcessGroup,
+ input: Tensor,
+ scatter_idx: int,
+ gather_idx: int,
+ ) -> Tensor:
+
+ ctx.group = group
+ ctx.scatter_idx = scatter_idx
+ ctx.gather_idx = gather_idx
+
+ return all_to_all_4D(input, scatter_idx, gather_idx, group=group)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+ return (
+ None,
+ SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
+ None,
+ None,
+ )
+
+
+def all_to_all_5D(input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None) -> torch.tensor:
+ """
+ all-to-all for QKV
+ forward (bs, seqlen/N, 3, hc, hs) -> (bs, seqlen, 3, hc/N, hs)
+
+ Args:
+ input (torch.tensor): a tensor sharded along dim scatter dim
+ scatter_idx (int): default 3
+ gather_idx (int): default 1
+ group : torch process group
+
+ Returns:
+ torch.tensor: resharded tensor, (bs, seqlen, 3, hc/P, hs) for the default (scatter_idx=3, gather_idx=1) direction
+ """
+ assert input.dim() == 5, f"input must be 5D tensor, got {input.dim()} and shape {input.shape}"
+
+ seq_world_size = dist.get_world_size(group)
+
+ if scatter_idx == 3 and gather_idx == 1:
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, 3, hc, hs) output: (bs, seqlen, 3, hc/P, hs)
+ bs, shard_seqlen, t_cnt, hc, hs = input.shape
+
+ assert t_cnt == 3
+ seqlen = shard_seqlen * seq_world_size
+ shard_hc = hc // seq_world_size
+
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
+ # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> (P, seq_len/P, 3, bs, hc/P, hs)
+ input_t = input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs).transpose(0, 3).contiguous()
+
+ output = torch.empty_like(input_t)
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
+ # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head
+ dist.barrier(group=group)
+ dist.all_to_all_single(output, input_t, group=group)
+
+ # if scattering the seq-dim, transpose the heads back to the original dimension
+ output = output.reshape(seqlen, 3, bs, shard_hc, hs)
+
+ # (seq_len, 3, bs, hc/P, hs) -trans-> (bs, seq_len, 3, hc/P, hs)
+ output = output.transpose(0, 2).transpose(1, 2).contiguous()
+
+ return output.reshape(bs, seqlen, 3, shard_hc, hs).contiguous()
+ elif scatter_idx == 1 and gather_idx == 3:
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
+ bs, seqlen, _, shard_hc, hs = input.shape
+ hc = shard_hc * seq_world_size
+ shard_seqlen = seqlen // seq_world_size
+ seq_world_size = dist.get_world_size(group)
+
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
+ # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)-> (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs)
+ input_t = (
+ input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs)
+ .transpose(0, 4)
+ .transpose(0, 1)
+ .contiguous()
+ .reshape(seq_world_size, shard_hc, shard_seqlen, 3, bs, hs)
+ )
+
+ output = torch.empty_like(input_t)
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
+ # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
+ dist.barrier(group=group)
+ dist.all_to_all_single(output, input_t, group=group)
+
+ # if scattering the seq-dim, transpose the heads back to the original dimension
+ output = output.reshape(hc, shard_seqlen, 3, bs, hs)
+
+ # (hc, seqlen/P, 3, bs, hs) -transpose(0,3)-> (bs, seqlen/P, 3, hc, hs)
+ output = output.transpose(0, 3).contiguous()
+
+ return output.reshape(bs, shard_seqlen, 3, hc, hs).contiguous()
+ else:
+ raise RuntimeError("scatter_idx must be 1 or 3 and gather_idx must be 1 or 3")
+
+
+class SeqAllToAll5D(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ group: dist.ProcessGroup,
+ input: Tensor,
+ scatter_idx: int = 3,
+ gather_idx: int = 1,
+ ) -> Tensor:
+
+ ctx.group = group
+ ctx.scatter_idx = scatter_idx
+ ctx.gather_idx = gather_idx
+
+ return all_to_all_5D(input, scatter_idx, gather_idx, group=group)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+ return (
+ None,
+ SeqAllToAll5D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
+ None,
+ None,
+ )
+
+
+class SeqAllGather(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx: Any, group: dist.ProcessGroup, input: Any) -> Tensor:
+ # ctx.group = group
+ ctx.save_for_backward(input[0])
+ all_gather_list = input[0]
+ all_gather_tensor = input[1]
+ dist.all_gather(all_gather_list, all_gather_tensor, group=group)
+ # torch.concat
+ return torch.stack(all_gather_list, dim=0)
+
+ @staticmethod
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
+ (tensor,) = ctx.saved_tensors
+ return None, (None, tensor)
diff --git a/llava/train/sequence_parallel/globals.py b/llava/train/sequence_parallel/globals.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8652e21b6682e610d23455fbdaa1f226115bdd7
--- /dev/null
+++ b/llava/train/sequence_parallel/globals.py
@@ -0,0 +1,274 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/feifeibear/long-context-attention
+# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
+
+import os
+
+import deepspeed.comm as dist
+import torch
+
+
+class Singleton:
+ _instance = None
+
+ def __new__(cls, *args, **kwargs):
+ if not cls._instance:
+ cls._instance = super().__new__(cls)
+ cls._instance.__initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if not self.__initialized:
+ self.__initialized = True
+
+
+class ProcessGroupManager(Singleton):
+ """
+ sp_degree = sp_ring_degree x sp_ulysses_degree
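+
+ Layout example (hypothetical): with world_size=8, ulysses_degree=2, ring_degree=2
+ (so sp_degree=4, dp_degree=2) and use_ulysses_low=True, the Ulysses groups are
+ {0,1} {2,3} {4,5} {6,7}, the ring groups are {0,2} {1,3} {4,6} {5,7}, and the data
+ parallel groups are {0,4} {1,5} {2,6} {3,7}.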
+ """
+
+ def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low, ring_type):
+ if not hasattr(self, "__initialized"):
+ super().__init__()
+ self.ulysses_degree = ulysses_degree
+ self.ring_type = ring_type
+ self.ulysses_seq_len = None
+
+ self.ring_degree = ring_degree
+ self.sp_degree = ring_degree * ulysses_degree
+ self.dp_degree = dp_degree
+
+ self.rank = dist.get_rank()
+
+ if self.ring_degree == 1:
+ # Using Ulysses Sequence Parallelism only
+ num_ulysses_pgs = self.dp_degree
+ self.ring_pg = None
+ self.ring_rank = None
+
+ for i in range(num_ulysses_pgs):
+ ulysses_ranks = list(range(i * self.ulysses_degree, (i + 1) * self.ulysses_degree))
+ group = dist.new_group(ulysses_ranks)
+ if self.rank in ulysses_ranks:
+ self.ulysses_pg = group
+
+ for sp_rank in range(self.sp_degree):
+ dp_ranks = list(range(sp_rank, self.dp_degree * self.sp_degree, self.sp_degree))
+ group = dist.new_group(dp_ranks)
+ if self.rank in dp_ranks:
+ self.dp_pg = group
+
+ self.ulysses_rank = dist.get_rank(self.ulysses_pg)
+ self.sp_rank = self.ulysses_rank
+ self.dp_rank = dist.get_rank(self.dp_pg)
+ self.sp_pg = self.ulysses_pg
+
+ print(f"GPU {torch.cuda.current_device()} Ulysses rank: {self.ulysses_rank} out of {self.sp_degree}")
+ else:
+ # Using Hybrid Sequence Parallelism
+ assert self.ring_degree > 1
+ num_ulysses_pgs = self.ring_degree # world_size // self.ulysses_degree
+ num_ring_pgs = self.ulysses_degree # world_size // self.ring_degree
+
+ # Set up process groups
+ if use_ulysses_low:
+ for dp_rank in range(dp_degree):
+ offset = dp_rank * self.sp_degree
+ for i in range(num_ulysses_pgs):
+ ulysses_ranks = list(
+ range(
+ i * self.ulysses_degree + offset,
+ (i + 1) * self.ulysses_degree + offset,
+ )
+ )
+ group = dist.new_group(ulysses_ranks)
+ if self.rank in ulysses_ranks:
+ self.ulysses_pg = group
+
+ for i in range(num_ring_pgs):
+ ring_ranks = list(range(i + offset, self.sp_degree + offset, num_ring_pgs))
+ group = dist.new_group(ring_ranks)
+ if self.rank in ring_ranks:
+ self.ring_pg = group
+
+ else:
+ for dp_rank in range(dp_degree):
+ offset = dp_rank * self.sp_degree
+ for i in range(num_ring_pgs):
+ ring_ranks = list(range(i * self.ring_degree + offset, (i + 1) * self.ring_degree + offset))
+ group = dist.new_group(ring_ranks)
+ if self.rank in ring_ranks:
+ self.ring_pg = group
+
+ for i in range(num_ulysses_pgs):
+ ulysses_ranks = list(range(i + offset, self.sp_degree + offset, num_ulysses_pgs))
+ group = dist.new_group(ulysses_ranks)
+ if self.rank in ulysses_ranks:
+ self.ulysses_pg = group
+
+ for sp_rank in range(self.sp_degree):
+ dp_ranks = list(range(sp_rank, self.dp_degree * self.sp_degree, self.sp_degree))
+ group = dist.new_group(dp_ranks)
+ if self.rank in dp_ranks:
+ self.dp_pg = group
+
+ for i in range(self.dp_degree):
+ sp_ranks = list(range(i * self.sp_degree, (i + 1) * self.sp_degree))
+ group = dist.new_group(sp_ranks)
+ if self.rank in sp_ranks:
+ self.sp_pg = group
+
+ self.ulysses_rank = dist.get_rank(self.ulysses_pg)
+ self.ring_rank = dist.get_rank(self.ring_pg)
+ self.dp_rank = dist.get_rank(self.dp_pg)
+
+ if use_ulysses_low:
+ self.sp_rank = self.ulysses_rank + self.ring_rank * self.ulysses_degree
+ else:
+ self.sp_rank = self.ring_rank + self.ulysses_rank * self.ring_degree
+
+ print(
+ f"Rank {self.rank}, GPU {torch.cuda.current_device()} Hybrid SP rank: {self.sp_rank} out of {self.sp_degree} (Ulysses: {self.ulysses_rank}/{self.ulysses_degree}, Ring: {self.ring_rank}/{self.ring_degree})"
+ )
+
+ print("--------------ProcessGroupManager Initialized---------------------")
+
+
+PROCESS_GROUP_MANAGER = None
+
+
+def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True, ring_type=None):
+ """
+ Set the process group manager for sequence parallelism.
+ sp_degree = sp_ring_degree x sp_ulysses_degree
+ """
+
+ # first check torch distributed group init and set device accordingly;
+ # (DL) TODO: Whether this can be skipped in DeepSpeed.
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ print(
+ "torch distributed is already initialized, " "skipping initialization ...",
+ flush=True,
+ )
+ else:
+ if int(os.environ["RANK"]) == 0:
+ print("Initializing Torch distributed.")
+ dist.init_distributed(dist_backend="nccl", dist_init_required=True)
+ local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+
+ torch.cuda.set_device(dist.get_rank() % local_world_size)
+
+ world_size = dist.get_world_size()
+
+ assert sp_degree <= world_size
+ assert world_size % sp_degree == 0, f"world_size {world_size} % sp_degree {sp_degree} != 0"
+
+ if sp_ring_degree < 1:
+ sp_ring_degree = 1
+
+ sp_ulysses_degree = sp_degree // sp_ring_degree
+ assert sp_degree % sp_ring_degree == 0, f"sp_degree {sp_degree} % sp_ring_degree {sp_ring_degree} != 0"
+
+ dp_degree = world_size // sp_degree
+
+ # Init the process group manager
+ global PROCESS_GROUP_MANAGER
+ PROCESS_GROUP_MANAGER = ProcessGroupManager(
+ sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low, ring_type
+ )
+
+
+def get_pg_manager():
+ return PROCESS_GROUP_MANAGER
+
+
+def get_sequence_parallel_size():
+ """Get the size of the sequence parallel group."""
+ return PROCESS_GROUP_MANAGER.sp_degree
+
+
+def get_sequence_parallel_rank():
+ """Get the rank of this process in the sequence parallel group the caller rank belongs to."""
+ return PROCESS_GROUP_MANAGER.sp_rank
+
+
+def get_sequence_parallel_pg():
+ """Get the overall sequence parallel process group (include Ring and Ulysses)."""
+ return PROCESS_GROUP_MANAGER.sp_pg
+
+
+def get_ulysses_sp_size():
+ """Get the size of the Ulysses sequence parallel group."""
+ return PROCESS_GROUP_MANAGER.ulysses_degree
+
+
+def get_ulysses_seq_len():
+ """Get the size of the Ulysses sequence parallel group."""
+ return PROCESS_GROUP_MANAGER.ulysses_seq_len
+
+
+def set_ulysses_seq_len(seq_len):
+ """Get the size of the Ulysses sequence parallel group."""
+ PROCESS_GROUP_MANAGER.ulysses_seq_len = seq_len
+
+
+def get_ulysses_sp_rank():
+ """Get the rank of this process in the Ulysses sequence parallel group the caller rank belongs to."""
+ return PROCESS_GROUP_MANAGER.ulysses_rank
+
+
+def get_ulysses_sp_pg():
+ """Get the Ulysses sequence parallel process group."""
+ return PROCESS_GROUP_MANAGER.ulysses_pg
+
+
+def get_ring_sp_size():
+ """Get the size of the RingAttn sequence parallel group."""
+ return PROCESS_GROUP_MANAGER.ring_degree
+
+
+def get_ring_sp_rank():
+ """Get the rank of this process in the RingAttn sequence parallel group the caller rank belongs to."""
+ return PROCESS_GROUP_MANAGER.ring_rank
+
+
+def get_ring_sp_pg():
+ """Get the RingAttn sequence parallel process group."""
+ return PROCESS_GROUP_MANAGER.ring_pg
+
+
+def get_ring_type():
+ """Get the RingAttn implementation type."""
+ return PROCESS_GROUP_MANAGER.ring_type
+
+
+def get_data_parallel_size():
+ """Get the size of the data parallel group."""
+ return PROCESS_GROUP_MANAGER.dp_degree
+
+
+def get_data_parallel_rank():
+ """Get the rank of this process in the data parallel group the caller rank belongs to."""
+ return PROCESS_GROUP_MANAGER.dp_rank
diff --git a/llava/train/sequence_parallel/hybrid_attn.py b/llava/train/sequence_parallel/hybrid_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2186ce00f8f8585a08d59a221f2cb30ecae73441
--- /dev/null
+++ b/llava/train/sequence_parallel/hybrid_attn.py
@@ -0,0 +1,466 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/feifeibear/long-context-attention
+# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
+
+import copy
+from typing import Any
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.nn import Module
+
+from .all_to_all import SeqAllToAll4D, SeqAllToAll5D
+from .globals import get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg
+from .ring import (
+ ring_flash_attn_func,
+ ring_flash_attn_qkvpacked_func,
+ ring_flash_attn_varlen_func,
+ ring_flash_attn_varlen_qkvpacked_func,
+ stripe_flash_attn_func,
+ stripe_flash_attn_qkvpacked_func,
+ zigzag_ring_flash_attn_func,
+ zigzag_ring_flash_attn_qkvpacked_func,
+ zigzag_ring_flash_attn_varlen_func,
+ zigzag_ring_flash_attn_varlen_qkvpacked_func,
+)
+
+RING_IMPL_DICT = {
+ "ring": ring_flash_attn_func,
+ "zigzag": zigzag_ring_flash_attn_func,
+ "strip": stripe_flash_attn_func,
+ "ring_varlen": ring_flash_attn_varlen_func,
+ "zigzag_ring_varlen": zigzag_ring_flash_attn_varlen_func,
+}
+
+RING_IMPL_QKVPACKED_DICT = {
+ "ring": ring_flash_attn_qkvpacked_func,
+ "zigzag": zigzag_ring_flash_attn_qkvpacked_func,
+ "strip": stripe_flash_attn_qkvpacked_func,
+ "ring_varlen": ring_flash_attn_varlen_qkvpacked_func,
+ "zigzag_varlen": zigzag_ring_flash_attn_varlen_qkvpacked_func,
+}
+
+
+class HybridAttention(torch.nn.Module):
+ """Initialization.
+
+ Arguments:
+ scatter_idx (int): scatter_idx for all2all comm
+ gather_idx (int): gather_idx for all2all comm
+ use_pack_qkv (bool): pack q/k/v into a single all-to-all (not supported yet)
+ attention_warper (Module): optional attention callable used in place of the ring kernel
+
+ The Ulysses and ring process groups are taken from the global ProcessGroupManager
+ (call set_pg_manager() first).
+ """
+
+ def __init__(
+ self,
+ scatter_idx: int = 2,
+ gather_idx: int = 1,
+ use_pack_qkv: bool = False,
+ attention_warper: Module = None,
+ ) -> None:
+
+ super().__init__()
+ self.ring_pg = get_ring_sp_pg()
+ self.ulysses_pg = get_ulysses_sp_pg()
+
+ self.use_pack_qkv = use_pack_qkv
+ assert (
+ self.ulysses_pg is not None or self.ring_pg is not None
+ ), f"use set_pg_manager() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
+ self.scatter_idx = scatter_idx
+ self.gather_idx = gather_idx
+ if attention_warper is None:
+ self.ring_attn_fn = RING_IMPL_DICT[get_ring_type()]
+ else:
+ self.ring_attn_fn = attention_warper
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ *args: Any,
+ attention_mask=None,
+ dropout_p=0.0,
+ softmax_scale=None,
+ seqlens_in_batch=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ ) -> Tensor:
+ """forward
+
+ Arguments:
+ query (Tensor): query input to the layer
+ key (Tensor): key input to the layer
+ value (Tensor): value input to the layer
+ args: other args
+
+ Returns:
+ * output (Tensor): context output
+ """
+
+ # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
+ # scatter 2, gather 1
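+        # Example (hypothetical sizes): with a ulysses degree of N=4, bs=1,
+        # head_cnt=32, head_size=128 and a local shard of 2048 tokens, each rank
+        # goes from (1, 2048, 32, 128) to (1, 8192, 8, 128) after the all-to-all:
+        # the full (ring-local) sequence, 1/N of the heads.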
+ if self.use_pack_qkv:
+
+ # TODO (Qinghao): To support packed qkv
+ raise NotImplementedError("Packed qkv is not supported yet.")
+ # (3*bs, seq_len/N, head_cnt, head_size)
+            qkv = torch.cat([query, key, value]).contiguous()
+ # (3*bs, seq_len, head_cnt/N, head_size)
+ qkv = SeqAllToAll4D.apply(self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx)
+ qkv = torch.chunk(qkv, 3, dim=0)
+ out = self.ring_attn_fn(
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ group=self.ring_pg,
+ )
+
+ query_layer = SeqAllToAll4D.apply(self.ulysses_pg, query, self.scatter_idx, self.gather_idx)
+ key_layer = SeqAllToAll4D.apply(self.ulysses_pg, key, self.scatter_idx, self.gather_idx)
+ value_layer = SeqAllToAll4D.apply(self.ulysses_pg, value, self.scatter_idx, self.gather_idx)
+
+ if attention_mask is not None:
+ new_attention_mask = torch.cat([attention_mask] * dist.get_world_size(self.ulysses_pg), dim=1)
+
+ out = self.ring_attn_fn(
+ query_layer,
+ key_layer,
+ value_layer,
+ *args,
+ attention_mask=new_attention_mask,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ seqlens_in_batch=seqlens_in_batch,
+ causal=causal,
+ group=self.ring_pg,
+ )
+ else:
+ out = self.ring_attn_fn(
+ query_layer,
+ key_layer,
+ value_layer,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ group=self.ring_pg,
+ )
+
+        if isinstance(out, tuple):
+ context_layer, _, _ = out
+ else:
+ context_layer = out
+
+ # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
+ # scatter 1, gather 2
+ output = SeqAllToAll4D.apply(self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx)
+
+ # out e.g., [s/p::h]
+ return output
+
+
+# TODO (Qinghao): To be supported
+class HybridAttentionQKVPacked(torch.nn.Module):
+ """Initialization.
+
+ Arguments:
+ ulysses_pg (ProcessGroup): ulysses process group
+ ring_pg (ProcessGroup): ring process group
+ scatter_idx (int): scatter_idx for all2all comm
+ gather_idx (int): gather_idx for all2all comm
+ """
+
+ def __init__(
+ self,
+ scatter_idx: int = 3,
+ gather_idx: int = 1,
+ ring_impl_type: str = "zigzag",
+ ) -> None:
+
+ super().__init__()
+
+ self.ring_pg = get_ring_sp_pg()
+ self.ulysses_pg = get_ulysses_sp_pg()
+
+ assert (
+ self.ulysses_pg is not None or self.ring_pg is not None
+ ), f"use set_pg_manager() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
+ self.scatter_idx = scatter_idx
+ self.gather_idx = gather_idx
+
+ self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type]
+
+ def forward(
+ self,
+ qkv,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ *args: Any,
+ ) -> Tensor:
+ """forward
+
+ Arguments:
+ query (Tensor): query input to the layer
+ key (Tensor): key input to the layer
+ value (Tensor): value input to the layer
+ args: other args
+
+ Returns:
+ * output (Tensor): context output
+ """
+
+ # scatter 3, gather 1
+
+ world_size = dist.get_world_size(self.ulysses_pg)
+
+ if world_size > 1 and dist.is_initialized():
+ qkv = SeqAllToAll5D.apply(self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx)
+
+ out = self.ring_attn_fn(
+ qkv,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ group=self.ring_pg,
+ )
+
+ # print(f"out {out.shape}")
+
+        if isinstance(out, tuple):
+ out = out[0]
+
+ # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
+ # scatter 1, gather 2
+
+ if world_size > 1 and dist.is_initialized():
+ out = SeqAllToAll4D.apply(self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1)
+ # out e.g., [s/p::h]
+ return out
+
+
+# TODO (Qinghao): To be supported
+class AsyncHybridAttention(torch.nn.Module):
+ """Initialization.
+
+ Arguments:
+ ulysses_pg (ProcessGroup): ulysses process group
+ ring_pg (ProcessGroup): ring process group
+ scatter_idx (int): scatter_idx for all2all comm
+ gather_idx (int): gather_idx for all2all comm
+ """
+
+ def __init__(
+ self,
+ scatter_idx: int = 2,
+ gather_idx: int = 1,
+ ring_impl_type: str = "zigzag",
+ ) -> None:
+
+ super().__init__()
+ self.ring_pg = get_ring_sp_pg()
+ self.ulysses_pg = get_ulysses_sp_pg()
+
+ self.stream = torch.cuda.Stream()
+ self._async_op = True
+
+ assert (
+ self.ulysses_pg is not None or self.ring_pg is not None
+ ), f"use set_pg_manager() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
+ self.scatter_idx = scatter_idx
+ self.gather_idx = gather_idx
+ self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type]
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ *args: Any,
+ ) -> Tensor:
+ """forward
+
+ Arguments:
+ query (Tensor): query input to the layer (bs, seqlen/P, hc, hs)
+ key (Tensor): key input to the layer (bs, seqlen/P, hc_kv, hs)
+ value (Tensor): value input to the layer (bs, seqlen/P, hc_kv, hs)
+ args: other args
+
+ Returns:
+ * output (Tensor): context output
+ """
+
+ # un*ud = hc
+
+ ulysses_degree = dist.get_world_size(self.ulysses_pg)
+
+ bs, shard_seqlen, hc, hs = query.shape
+ bs, shard_seqlen, hc_kv, hs = key.shape
+ seq_len = shard_seqlen * ulysses_degree
+ un = hc // ulysses_degree
+ un_kv = hc_kv // ulysses_degree
+
+ assert un_kv == un, f"un_kv {un_kv} un {un}"
+
+ qkv = torch.cat([query, key, value]).contiguous()
+ # (3*bs, seqlen/P, hc, hs) -> (hc, seqlen/P, 3*bs, hs) -> (un, ud, seqlen/P, 3*bs, hs), where hc = un*ud
+ qkv_list = torch.unbind(qkv.transpose(0, 2).contiguous().reshape(un, ulysses_degree, shard_seqlen, 3 * bs, hs))
+ # 3xall-to-all output buffer
+ qkv_trans_list = [
+ torch.zeros(
+ ulysses_degree,
+ 1,
+ shard_seqlen,
+ 3 * bs,
+ hs,
+ dtype=query.dtype,
+ device=query.device,
+ )
+ for i in range(len(qkv_list))
+ ]
+        # buffer for the last all-to-all
+ context_layer_list = [
+ torch.zeros(
+ ulysses_degree,
+ 1,
+ shard_seqlen,
+ bs,
+ hs,
+ dtype=query.dtype,
+ device=query.device,
+ )
+ for i in range(len(qkv_list))
+ ]
+
+ comm_handle_list = []
+
+ # un * (ud, shard_seqlen, 3*bs, hs)
+ for i, qkv in enumerate(qkv_list):
+ with torch.cuda.stream(self.stream):
+ ret = dist.all_to_all_single(
+ qkv_trans_list[i],
+ qkv,
+ group=self.ulysses_pg,
+ async_op=self._async_op,
+ )
+ comm_handle_list.append(ret)
+
+ last_comm_handle_list = []
+ for i, qkv_trans in enumerate(qkv_trans_list):
+ if comm_handle_list[i] is not None:
+ comm_handle_list[i].wait()
+ qkv_trans = (
+ qkv_trans.reshape(seq_len, 3 * bs, 1, hs).transpose(0, 1).contiguous().reshape(3 * bs, seq_len, 1, hs)
+ )
+
+ # qkv_trans = all_to_all_4D_async(qkv, qkv_trans_list[i], self.scatter_idx, self.gather_idx, self.ulysses_pg)
+ qkv_trans = torch.chunk(qkv_trans, 3, dim=0)
+
+ out = self.ring_attn_fn(
+ qkv_trans[0],
+ qkv_trans[1],
+ qkv_trans[2],
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ group=self.ring_pg,
+ )
+
+            if isinstance(out, tuple):
+ context_layer, _, _ = out
+ else:
+ context_layer = out
+
+ # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
+ # scatter 1, gather 2
+
+ context_layer = (
+ context_layer.reshape(bs, ulysses_degree, shard_seqlen, 1, hs)
+ .transpose(0, 3)
+ .transpose(0, 1)
+ .contiguous()
+ .reshape(ulysses_degree, 1, shard_seqlen, bs, hs)
+ )
+ with torch.cuda.stream(self.stream):
+ ret = dist.all_to_all_single(
+ context_layer_list[i],
+ context_layer,
+ group=self.ulysses_pg,
+ async_op=self._async_op,
+ )
+ last_comm_handle_list.append(ret)
+
+ # hc = un * P
+ # un x (hc = P, seq_len/P, bs, hs) -> (bs, seq_len, hc = P, hs)
+ for i, ret in enumerate(last_comm_handle_list):
+ if ret is not None:
+ ret.wait()
+ context_layer_list[i] = (
+ context_layer_list[i]
+ .reshape(ulysses_degree, shard_seqlen, bs, hs)
+ .transpose(0, 2)
+ .contiguous()
+ .reshape(bs, shard_seqlen, ulysses_degree, hs)
+ )
+
+ output = torch.cat(context_layer_list, dim=2)
+ return output
+
+ def backward(self, *args, **kwargs):
+ raise RuntimeError("Backward computation is not allowed for AsyncHybridAttention.")
diff --git a/llava/train/sequence_parallel/input_utils.py b/llava/train/sequence_parallel/input_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..edcc4876ba02732bd11a18d541a87fe39656747a
--- /dev/null
+++ b/llava/train/sequence_parallel/input_utils.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
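+# Zigzag sharding for ring attention: the sequence is cut into 2 * world_size
+# chunks and rank r keeps chunks r and (2 * world_size - 1 - r), which balances
+# the causal-attention work across ranks. For example, with world_size=2 and
+# chunks [0, 1, 2, 3], rank 0 holds [0, 3] and rank 1 holds [1, 2].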
+def extract_local_zigzag(value, rank, world_size, device, dim=1):
+ value_chunks = value.chunk(2 * world_size, dim=dim)
+ local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim)
+ return local_value.to(device)
+
+
+def extract_local_from_list(value_list, sp_rank, sp_size):
+ quotient, remainder = divmod(len(value_list), sp_size)
+ start_idx = sp_rank * quotient + min(sp_rank, remainder)
+ end_idx = (sp_rank + 1) * quotient + min(sp_rank + 1, remainder)
+ return value_list[start_idx:end_idx]
+
+
+def extract_local_from_list_zigzag(value_list, sp_rank, sp_size):
+ chunk_size, remainder = divmod(len(value_list), (2 * sp_size))
+ value_chunks = []
+ start_idx = 0
+ for i in range(2 * sp_size):
+ extra = 1 if i < remainder else 0
+ end_idx = start_idx + chunk_size + extra
+ value_chunks.append(value_list[start_idx:end_idx])
+ start_idx = end_idx
+
+ local_value = value_chunks[sp_rank] + value_chunks[2 * sp_size - sp_rank - 1]
+ return local_value
+
+
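+# The two helpers below shard a sequence at image-token boundaries: image_positions
+# marks where each image's tokens start, and the quotient/remainder split assigns a
+# balanced number of images to each sequence-parallel rank, so no image's token block
+# is cut in half. Rank 0 additionally keeps the head of the sequence and the last
+# rank keeps the tail.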
+def extract_local_input_ids(input_ids, image_positions, sp_rank, sp_size, bos_token_id=1, image_token_len=3):
+ quotient, remainder = divmod(len(image_positions), sp_size)
+ start_idx = sp_rank * quotient + min(sp_rank, remainder)
+ end_idx = (sp_rank + 1) * quotient + min(sp_rank + 1, remainder)
+
+ start_position_idx = image_positions[start_idx]
+ if sp_rank != sp_size - 1:
+ end_position_idx = image_positions[end_idx]
+ else:
+ end_position_idx = len(input_ids)
+
+ if sp_rank == 0: # Handle the head of the sequence
+ return input_ids[0:end_position_idx]
+ elif sp_rank == sp_size - 1: # Handle the tail of the sequence
+ return input_ids[start_position_idx:]
+ else:
+ return input_ids[start_position_idx:end_position_idx]
+
+
+def extract_local_position_ids(input_ids, image_positions, image_ids, sp_rank, sp_size, image_token_len=198):
+ quotient, remainder = divmod(len(image_ids), sp_size)
+ start_idx = sp_rank * quotient + min(sp_rank, remainder)
+ end_idx = (sp_rank + 1) * quotient + min(sp_rank + 1, remainder)
+ start_position_idx = image_positions[start_idx] + image_ids[start_idx] * image_token_len
+ if sp_rank != sp_size - 1: # Handle the tail of the sequence
+ end_position_idx = image_positions[end_idx] + image_ids[end_idx] * image_token_len # image_token_len + 3
+ else:
+ end_position_idx = len(input_ids)
+ if sp_rank == 0: # Handle the head of the sequence
+ return input_ids[0:end_position_idx]
+ elif sp_rank == sp_size - 1: # Handle the tail of the sequence
+ return input_ids[start_position_idx:]
+ else:
+ return input_ids[start_position_idx:end_position_idx]
diff --git a/llava/train/sequence_parallel/monkey_patch.py b/llava/train/sequence_parallel/monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b571f7e6ea7710d06942655a205a02c913414acb
--- /dev/null
+++ b/llava/train/sequence_parallel/monkey_patch.py
@@ -0,0 +1,256 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import inspect
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import pad_input
+from transformers.cache_utils import Cache
+from transformers.modeling_flash_attention_utils import _upad_input
+from transformers.utils import is_flash_attn_greater_or_equal
+
+from llava.model.utils.packing import _get_unpad_data
+from llava.train.sequence_parallel.globals import get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg
+from llava.train.sequence_parallel.hybrid_attn import HybridAttention
+from llava.train.sequence_parallel.ring import ring_flash_attn_varlen_func, zigzag_ring_flash_attn_varlen_func
+from llava.train.sequence_parallel.ulysses_attn import UlyssesAttention
+
+
+def _ulysses_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ query_length,
+ attention_mask=None,
+ dropout_p=0.0,
+ softmax_scale=None,
+ seqlens_in_batch=None,
+ causal=None,
+):
+ batch_size = query_states.shape[0]
+
+    # overwrite query_length with the actual length of the sequence after SP communication
+ query_length = attention_mask.shape[1]
+
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=True,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+
+ return attn_output
+
+
+def _hybrid_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ query_length,
+ attention_mask=None,
+ dropout_p=0.0,
+ softmax_scale=None,
+ seqlens_in_batch=None,
+ causal=None,
+ group=None,
+):
+ batch_size = query_states.shape[0]
+
+    # overwrite query_length with the actual length of the sequence after SP communication
+ query_length = attention_mask.shape[1]
+ _get_unpad_data.seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seq_lens = cu_seq_lens[0]
+
+ ring_type = get_ring_type()
+ if ring_type == "ring_varlen":
+ attn_output_unpad = ring_flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seq_lens,
+ max_seq_lens[0],
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=True,
+ group=group,
+ )
+ elif ring_type == "zigzag_ring_varlen":
+ attn_output_unpad = zigzag_ring_flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seq_lens,
+ max_seq_lens[0],
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=True,
+ group=group,
+ )
+ else:
+ raise ValueError(f"Invalid ring_type: {ring_type}")
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+
+ return attn_output
+
+
+def _flash_attention_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+ is_causal: bool,
+ dropout: float = 0.0,
+ position_ids: Optional[torch.Tensor] = None,
+ softmax_scale: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ use_top_left_mask: bool = False,
+ softcap: Optional[float] = None,
+    deterministic: Optional[bool] = None,
+):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_top_left_mask (`bool`, defaults to `False`):
+            flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference.
+ softcap (`float`, *optional*):
+ Softcap for the attention logits, used e.g. in gemma2.
+ deterministic (`bool`, *optional*):
+ Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+ """
+ if not use_top_left_mask:
+ causal = is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
+ causal = is_causal and query_length != 1
+
+ # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
+ use_sliding_windows = (
+ "window_size" in list(inspect.signature(flash_attn_func).parameters)
+ and sliding_window is not None
+ and key_states.shape[1] > sliding_window
+ )
+ flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
+
+ if is_flash_attn_greater_or_equal("2.4.1"):
+ if deterministic is None:
+ deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+ flash_kwargs["deterministic"] = deterministic
+
+ if softcap is not None:
+ flash_kwargs["softcap"] = softcap
+
+ ring_enabled = get_ring_sp_pg() is not None
+
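+    # Dispatch: with a padding mask we go through the varlen (unpadded) kernels,
+    # otherwise through the dense ones; with a ring group configured we use the
+    # hybrid Ulysses + ring path, otherwise Ulysses-only all-to-all attention.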
+ if attention_mask is not None:
+ if ring_enabled:
+ attn_output = HybridAttention(attention_warper=_hybrid_attn_varlen_func)(
+ query_states,
+ key_states,
+ value_states,
+ query_length,
+ attention_mask=attention_mask,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ seqlens_in_batch=_get_unpad_data.seqlens_in_batch,
+ )
+ else:
+ attn_output = UlyssesAttention(_ulysses_attn_varlen_func, get_ulysses_sp_pg())(
+ query_states,
+ key_states,
+ value_states,
+ query_length,
+ attention_mask=attention_mask,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ seqlens_in_batch=_get_unpad_data.seqlens_in_batch,
+ )
+ else:
+ if ring_enabled:
+ attn_output = HybridAttention()(
+ query_states,
+ key_states,
+ value_states,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=is_causal,
+ )
+ else:
+ attn_output = UlyssesAttention(flash_attn_func, get_ulysses_sp_pg())(
+ query_states,
+ key_states,
+ value_states,
+ query_length,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ )
+ return attn_output
+
+
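+# Monkey-patch target: returns the 2D attention_mask unchanged so the model does
+# not build a dense 4D causal mask; causality is handled inside the flash/ring
+# attention kernels above (assumption based on the patched function's name).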
+def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+):
+ return attention_mask
diff --git a/llava/train/sequence_parallel/ring/__init__.py b/llava/train/sequence_parallel/ring/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e654037cf76056a6fdcc44beb9468e8f3d0bacc
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
+
+
+from .ring_flash_attn import ring_flash_attn_func, ring_flash_attn_kvpacked_func, ring_flash_attn_qkvpacked_func
+from .ring_flash_attn_varlen import (
+ ring_flash_attn_varlen_func,
+ ring_flash_attn_varlen_kvpacked_func,
+ ring_flash_attn_varlen_qkvpacked_func,
+)
+from .stripe_flash_attn import stripe_flash_attn_func, stripe_flash_attn_kvpacked_func, stripe_flash_attn_qkvpacked_func
+from .zigzag_ring_flash_attn import (
+ zigzag_ring_flash_attn_func,
+ zigzag_ring_flash_attn_kvpacked_func,
+ zigzag_ring_flash_attn_qkvpacked_func,
+)
+from .zigzag_ring_flash_attn_varlen import (
+ zigzag_ring_flash_attn_varlen_func,
+ zigzag_ring_flash_attn_varlen_qkvpacked_func,
+)
diff --git a/llava/train/sequence_parallel/ring/ring_flash_attn.py b/llava/train/sequence_parallel/ring/ring_flash_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6a796acce948abf33db942f2fc1da9d41b0f4
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/ring_flash_attn.py
@@ -0,0 +1,306 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
+
+import torch
+from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
+
+from .utils import RingComm, update_out_and_lse
+
+
+def ring_flash_attn_forward(
+ process_group,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ comm = RingComm(process_group)
+
+ out = None
+ lse = None
+
+ next_k, next_v = None, None
+
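+    # Ring loop sketch: at every step each rank ships its current KV block to the
+    # next rank while receiving one from the previous rank, computes flash attention
+    # of its local Q against the block it holds, and folds the partial result into
+    # (out, lse) via a numerically stable log-sum-exp merge. With causal=True,
+    # blocks originating from higher ranks are skipped (step > rank).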
+ for step in range(comm.world_size):
+ if step + 1 != comm.world_size:
+ next_k: torch.Tensor = comm.send_recv(k)
+ next_v: torch.Tensor = comm.send_recv(v)
+ comm.commit()
+
+ if not causal or step <= comm.rank:
+ block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal=causal and step == 0,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ return_softmax=True and dropout_p > 0,
+ )
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+
+ if step + 1 != comm.world_size:
+ comm.wait()
+ k = next_k
+ v = next_v
+
+ out = out.to(q.dtype)
+ lse = lse.squeeze(dim=-1).transpose(1, 2)
+ return out, lse
+
+
+def ring_flash_attn_backward(
+ process_group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ kv_comm = RingComm(process_group)
+ d_kv_comm = RingComm(process_group)
+ dq, dk, dv = None, None, None
+ next_dk, next_dv = None, None
+
+ block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+ block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+ block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
+
+ next_dk, next_dv = None, None
+ next_k, next_v = None, None
+
+ for step in range(kv_comm.world_size):
+ if step + 1 != kv_comm.world_size:
+ next_k = kv_comm.send_recv(k)
+ next_v = kv_comm.send_recv(v)
+ kv_comm.commit()
+ if step <= kv_comm.rank or not causal:
+ bwd_causal = causal and step == 0
+ _flash_attn_backward(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ block_dq_buffer,
+ block_dk_buffer,
+ block_dv_buffer,
+ dropout_p,
+ softmax_scale,
+ bwd_causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ rng_state=None,
+ )
+
+ if dq is None:
+ dq = block_dq_buffer.to(torch.float32)
+ dk = block_dk_buffer.to(torch.float32)
+ dv = block_dv_buffer.to(torch.float32)
+ else:
+ dq += block_dq_buffer
+ d_kv_comm.wait()
+ dk = block_dk_buffer + next_dk
+ dv = block_dv_buffer + next_dv
+ elif step != 0:
+ d_kv_comm.wait()
+ dk = next_dk
+ dv = next_dv
+
+ if step + 1 != kv_comm.world_size:
+ kv_comm.wait()
+ k = next_k
+ v = next_v
+
+ next_dk = d_kv_comm.send_recv(dk)
+ next_dv = d_kv_comm.send_recv(dv)
+ d_kv_comm.commit()
+
+ d_kv_comm.wait()
+
+ return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype)
+
+
+class RingFlashAttnFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_softmax,
+ group,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ assert alibi_slopes is None
+ k = k.contiguous()
+ v = v.contiguous()
+ out, softmax_lse = ring_flash_attn_forward(
+ group,
+ q,
+ k,
+ v,
+ softmax_scale=softmax_scale,
+ dropout_p=dropout_p,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=False,
+ )
+ # this should be out_padded
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
+ ctx.dropout_p = dropout_p
+ ctx.softmax_scale = softmax_scale
+ ctx.causal = causal
+ ctx.window_size = window_size
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+ ctx.group = group
+ return out if not return_softmax else (out, softmax_lse, None)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ q, k, v, out, softmax_lse = ctx.saved_tensors
+ dq, dk, dv = ring_flash_attn_backward(
+ ctx.group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale=ctx.softmax_scale,
+ dropout_p=ctx.dropout_p,
+ causal=ctx.causal,
+ window_size=ctx.window_size,
+ alibi_slopes=ctx.alibi_slopes,
+ deterministic=ctx.deterministic,
+ )
+ return dq, dk, dv, None, None, None, None, None, None, None, None
+
+
+def ring_flash_attn_qkvpacked_func(
+ qkv,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return RingFlashAttnFunc.apply(
+ qkv[:, :, 0],
+ qkv[:, :, 1],
+ qkv[:, :, 2],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def ring_flash_attn_kvpacked_func(
+ q,
+ kv,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return RingFlashAttnFunc.apply(
+ q,
+ kv[:, :, 0],
+ kv[:, :, 1],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def ring_flash_attn_func(
+ q,
+ k,
+ v,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return RingFlashAttnFunc.apply(
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
diff --git a/llava/train/sequence_parallel/ring/ring_flash_attn_varlen.py b/llava/train/sequence_parallel/ring/ring_flash_attn_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bd0600ba082f0e86a575a71f19e2de107ab6735
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/ring_flash_attn_varlen.py
@@ -0,0 +1,343 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
+
+import torch
+from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward
+
+from .utils import RingComm, update_out_and_lse
+
+try:
+ from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse
+except Exception:  # fall back to the pure-PyTorch lse helpers if the triton versions are unavailable
+ from .utils import flatten_varlen_lse, unflatten_varlen_lse
+
+
+def ring_flash_attn_varlen_forward(
+ process_group,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ cu_seqlens,
+ max_seqlen,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ comm = RingComm(process_group)
+
+ out = None
+ lse = None
+ next_k, next_v = None, None
+
+ for step in range(comm.world_size):
+ if step + 1 != comm.world_size:
+ next_k: torch.Tensor = comm.send_recv(k)
+ next_v: torch.Tensor = comm.send_recv(v)
+ comm.commit()
+ if not causal or step <= comm.rank:
+ block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
+ q,
+ k,
+ v,
+ cu_seqlens,
+ cu_seqlens,
+ max_seqlen,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal=causal and step == 0,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ return_softmax=True and dropout_p > 0,
+ block_table=None,
+ )
+
+ block_lse = flatten_varlen_lse(block_lse, cu_seqlens=cu_seqlens)
+
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+
+ if step + 1 != comm.world_size:
+ comm.wait()
+ k = next_k
+ v = next_v
+
+ out = out.to(q.dtype)
+ lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
+ return out, lse
+
+
+def ring_flash_attn_varlen_backward(
+ process_group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ cu_seqlens,
+ max_seqlen,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ kv_comm = RingComm(process_group)
+ d_kv_comm = RingComm(process_group)
+ dq, dk, dv = None, None, None
+ next_dk, next_dv = None, None
+
+ block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+ block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+ block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
+
+ next_dk, next_dv = None, None
+ next_k, next_v = None, None
+ for step in range(kv_comm.world_size):
+ if step + 1 != kv_comm.world_size:
+ next_k = kv_comm.send_recv(k)
+ next_v = kv_comm.send_recv(v)
+ kv_comm.commit()
+ if step <= kv_comm.rank or not causal:
+ bwd_causal = causal and step == 0
+ _flash_attn_varlen_backward(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ block_dq_buffer,
+ block_dk_buffer,
+ block_dv_buffer,
+ cu_seqlens,
+ cu_seqlens,
+ max_seqlen,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ bwd_causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ rng_state=None,
+ )
+
+ if dq is None:
+ dq = block_dq_buffer.to(torch.float32)
+ dk = block_dk_buffer.to(torch.float32)
+ dv = block_dv_buffer.to(torch.float32)
+ else:
+ dq += block_dq_buffer
+ d_kv_comm.wait()
+ dk = block_dk_buffer + next_dk
+ dv = block_dv_buffer + next_dv
+ elif step != 0:
+ d_kv_comm.wait()
+ dk = next_dk
+ dv = next_dv
+
+ if step + 1 != kv_comm.world_size:
+ kv_comm.wait()
+ k = next_k
+ v = next_v
+
+ next_dk = d_kv_comm.send_recv(dk)
+ next_dv = d_kv_comm.send_recv(dv)
+ d_kv_comm.commit()
+
+ d_kv_comm.wait()
+
+ return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype)
+
+
+class RingFlashAttnVarlenFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_softmax,
+ group,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ assert alibi_slopes is None
+ k = k.contiguous()
+ v = v.contiguous()
+ out, softmax_lse = ring_flash_attn_varlen_forward(
+ group,
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ softmax_scale=softmax_scale,
+ dropout_p=dropout_p,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=False,
+ )
+ # this should be out_padded
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens)
+ ctx.max_seqlen = max_seqlen
+ ctx.dropout_p = dropout_p
+ ctx.softmax_scale = softmax_scale
+ ctx.causal = causal
+ ctx.window_size = window_size
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+ ctx.group = group
+ return out if not return_softmax else (out, softmax_lse, None)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors
+ dq, dk, dv = ring_flash_attn_varlen_backward(
+ ctx.group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ cu_seqlens,
+ ctx.max_seqlen,
+ softmax_scale=ctx.softmax_scale,
+ dropout_p=ctx.dropout_p,
+ causal=ctx.causal,
+ window_size=ctx.window_size,
+ alibi_slopes=ctx.alibi_slopes,
+ deterministic=ctx.deterministic,
+ )
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
+
+
+def ring_flash_attn_varlen_qkvpacked_func(
+ qkv,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return RingFlashAttnVarlenFunc.apply(
+ qkv[:, 0],
+ qkv[:, 1],
+ qkv[:, 2],
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def ring_flash_attn_varlen_kvpacked_func(
+ q,
+ kv,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return RingFlashAttnVarlenFunc.apply(
+ q,
+ kv[:, 0],
+ kv[:, 1],
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def ring_flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return RingFlashAttnVarlenFunc.apply(
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
diff --git a/llava/train/sequence_parallel/ring/stripe_flash_attn.py b/llava/train/sequence_parallel/ring/stripe_flash_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..356030740c546984230c4ba004ac1319413ceb7e
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/stripe_flash_attn.py
@@ -0,0 +1,350 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Striped Attention Paper: https://arxiv.org/abs/2311.09431
+
+import torch
+from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
+
+from .utils import RingComm, update_out_and_lse
+
+
+def stripe_flash_attn_forward(
+ process_group,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead"
+ comm = RingComm(process_group)
+
+ out = None
+ lse = None
+
+ next_k, next_v = None, None
+
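+    # Striped layout sketch: tokens are assigned to ranks round-robin, so when the
+    # received KV block originates from a higher rank (step > rank), a key at the
+    # same local index is globally *later* than the matching query; shifting with
+    # q[:, 1:] against k[:, :-1] under the causal mask excludes that diagonal and
+    # keeps attention strictly causal.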
+ for step in range(comm.world_size):
+ if step + 1 != comm.world_size:
+ next_k: torch.Tensor = comm.send_recv(k)
+ next_v: torch.Tensor = comm.send_recv(v)
+ comm.commit()
+
+ if step <= comm.rank:
+ block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ return_softmax=True and dropout_p > 0,
+ )
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+ else:
+ block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
+ q[:, 1:],
+ k[:, :-1],
+ v[:, :-1],
+ dropout_p,
+ softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ return_softmax=True and dropout_p > 0,
+ )
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None)))
+
+ if step + 1 != comm.world_size:
+ comm.wait()
+ k = next_k
+ v = next_v
+
+ out = out.to(q.dtype)
+ lse = lse.squeeze(dim=-1).transpose(1, 2)
+ return out, lse
+
+
+def stripe_flash_attn_backward(
+ process_group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+    assert causal, "stripe flash attn only supports causal attention; if not causal, use ring flash attn instead"
+ kv_comm = RingComm(process_group)
+ d_kv_comm = RingComm(process_group)
+ dq, dk, dv = None, None, None
+ next_dk, next_dv = None, None
+ next_k, next_v = None, None
+    dk_comm_buffer, dv_comm_buffer = None, None
+    softmax_lse_1 = None
+
+ block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+ block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+ block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
+ for step in range(kv_comm.world_size):
+ if step + 1 != kv_comm.world_size:
+ next_k = kv_comm.send_recv(k)
+ next_v = kv_comm.send_recv(v)
+ kv_comm.commit()
+
+        shift_causal = step > kv_comm.rank
+ if not shift_causal:
+ _flash_attn_backward(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ block_dq_buffer,
+ block_dk_buffer,
+ block_dv_buffer,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ rng_state=None,
+ )
+ else:
+ if softmax_lse_1 is None:
+ # lazy init, since the last rank does not need softmax_lse_1
+ softmax_lse_1 = softmax_lse[:, :, 1:].contiguous()
+ _flash_attn_backward(
+ dout[:, 1:],
+ q[:, 1:],
+ k[:, :-1],
+ v[:, :-1],
+ out[:, 1:],
+ softmax_lse_1,
+ block_dq_buffer[:, 1:],
+ block_dk_buffer[:, :-1],
+ block_dv_buffer[:, :-1],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ rng_state=None,
+ )
+
+ if dq is None:
+ dq = block_dq_buffer.to(torch.float32)
+ dk = block_dk_buffer.to(torch.float32)
+ dv = block_dv_buffer.to(torch.float32)
+ else:
+ if not shift_causal:
+ dq += block_dq_buffer
+ else:
+ dq[:, 1:] += block_dq_buffer[:, 1:]
+ d_kv_comm.wait()
+ dk_comm_buffer, dv_comm_buffer = dk, dv
+ dk = next_dk
+ dv = next_dv
+
+ if not shift_causal:
+ dk = block_dk_buffer + dk
+ dv = block_dv_buffer + dv
+ else:
+ dk[:, :-1] += block_dk_buffer[:, :-1]
+ dv[:, :-1] += block_dv_buffer[:, :-1]
+
+ if step + 1 != kv_comm.world_size:
+ kv_comm.wait()
+ k = next_k
+ v = next_v
+
+ next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
+ next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
+ d_kv_comm.commit()
+
+ d_kv_comm.wait()
+
+ return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
+
+
+class StripeFlashAttnFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_softmax,
+ group,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ assert alibi_slopes is None
+ k = k.contiguous()
+ v = v.contiguous()
+ out, softmax_lse = stripe_flash_attn_forward(
+ group,
+ q,
+ k,
+ v,
+ softmax_scale=softmax_scale,
+ dropout_p=dropout_p,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=False,
+ )
+ # this should be out_padded
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
+ ctx.dropout_p = dropout_p
+ ctx.softmax_scale = softmax_scale
+ ctx.causal = causal
+ ctx.window_size = window_size
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+ ctx.group = group
+ return out if not return_softmax else (out, softmax_lse, None)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ q, k, v, out, softmax_lse = ctx.saved_tensors
+ dq, dk, dv = stripe_flash_attn_backward(
+ ctx.group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale=ctx.softmax_scale,
+ dropout_p=ctx.dropout_p,
+ causal=ctx.causal,
+ window_size=ctx.window_size,
+ alibi_slopes=ctx.alibi_slopes,
+ deterministic=ctx.deterministic,
+ )
+ return dq, dk, dv, None, None, None, None, None, None, None, None
+
+
+def stripe_flash_attn_qkvpacked_func(
+ qkv,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return StripeFlashAttnFunc.apply(
+ qkv[:, :, 0],
+ qkv[:, :, 1],
+ qkv[:, :, 2],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def stripe_flash_attn_kvpacked_func(
+ q,
+ kv,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return StripeFlashAttnFunc.apply(
+ q,
+ kv[:, :, 0],
+ kv[:, :, 1],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def stripe_flash_attn_func(
+ q,
+ k,
+ v,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return StripeFlashAttnFunc.apply(
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
diff --git a/llava/train/sequence_parallel/ring/triton_utils.py b/llava/train/sequence_parallel/ring/triton_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bc3641918f3b3f148c88e199a9121af25694c96
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/triton_utils.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def flatten_kernel(
+ # pointers to matrices
+ OUT,
+ LSE,
+ CU_SEQLENS,
+ # strides
+ stride_out_nheads,
+ stride_out_seqlen,
+ stride_lse_batch,
+ stride_lse_nheads,
+ stride_lse_seqlen,
+ # meta-parameters
+ BLOCK_M: tl.constexpr,
+):
+ pid_m = tl.program_id(axis=0)
+ pid_batch = tl.program_id(axis=1)
+ pid_head = tl.program_id(axis=2)
+
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+ LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads
+ OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen
+
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+ LSE = LSE + rm[:, None] * stride_lse_seqlen
+ x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)
+
+ OUT = OUT + rm[:, None] * stride_out_seqlen
+ tl.store(OUT, x, mask=rm[:, None] < seqlen)
+
+
+def flatten_varlen_lse(lse, cu_seqlens):
+ """
+ Arguments:
+ lse: (batch_size, nheads, max_seqlen)
+ cu_seqlens: (batch_size + 1,)
+ Return:
+ flatten_lse: (nheads, total_seqlen)
+ """
+ total_seqlen = cu_seqlens[-1]
+ batch_size, nheads, max_seqlen = lse.shape
+ output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device)
+
+ grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
+ BLOCK_M = 4
+
+ with torch.cuda.device(lse.device.index):
+ flatten_kernel[grid](
+ output,
+ lse,
+ cu_seqlens,
+ # strides
+ output.stride(0),
+ output.stride(1),
+ lse.stride(0),
+ lse.stride(1),
+ lse.stride(2),
+ BLOCK_M,
+ )
+ return output
+
+
+@triton.jit
+def unflatten_kernel(
+ # pointers to matrices
+ OUT,
+ LSE,
+ CU_SEQLENS,
+ # strides
+ stride_out_batch,
+ stride_out_nheads,
+ stride_out_seqlen,
+ stride_lse_seqlen,
+ stride_lse_nheads,
+ # meta-parameters
+ BLOCK_M: tl.constexpr,
+):
+ pid_m = tl.program_id(axis=0)
+ pid_batch = tl.program_id(axis=1)
+ pid_head = tl.program_id(axis=2)
+
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+ LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen
+ OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+ LSE = LSE + rm[:, None] * stride_lse_seqlen
+ x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)
+
+ OUT = OUT + rm[:, None] * stride_out_seqlen
+ tl.store(OUT, x, mask=rm[:, None] < seqlen)
+
+
+def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
+ """
+ Arguments:
+ lse: (total_seqlen, nheads, 1)
+ cu_seqlens: (batch_size + 1,)
+ max_seqlen: int
+ Return:
+ unflatten_lse: (batch_size, nheads, max_seqlen)
+ """
+ lse = lse.unsqueeze(dim=-1)
+ batch_size = len(cu_seqlens) - 1
+ nheads = lse.shape[1]
+ output = torch.empty(
+ (batch_size, nheads, max_seqlen),
+ dtype=lse.dtype,
+ device=lse.device,
+ )
+
+ grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
+ BLOCK_M = 4
+
+ with torch.cuda.device(lse.device.index):
+ unflatten_kernel[grid](
+ output,
+ lse,
+ cu_seqlens,
+ # strides
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ lse.stride(0),
+ lse.stride(1),
+ BLOCK_M,
+ )
+ return output
diff --git a/llava/train/sequence_parallel/ring/utils.py b/llava/train/sequence_parallel/ring/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f967d13381398ad7906c445818c3634d92043be9
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/utils.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
+
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+__all__ = ["update_out_and_lse", "RingComm"]
+
+
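+# Streaming softmax merge: for two partial results with log-sum-exp values lse and
+# block_lse, the exact combination is
+#   new_lse = log(exp(lse) + exp(block_lse))
+#   new_out = (exp(lse) * out + exp(block_lse) * block_out) / (exp(lse) + exp(block_lse))
+# which the sigmoid / logsigmoid form below computes in a numerically stable way:
+#   new_out = out - sigmoid(block_lse - lse) * (out - block_out)
+#   new_lse = lse - logsigmoid(lse - block_lse)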
+@torch.jit.script
+def _update_out_and_lse(
+ out: torch.Tensor,
+ lse: torch.Tensor,
+ block_out: torch.Tensor,
+ block_lse: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ block_out = block_out.to(torch.float32)
+ block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
+
+ out = out - F.sigmoid(block_lse - lse) * (out - block_out)
+ lse = lse - F.logsigmoid(lse - block_lse)
+
+ return out, lse
+
+
+def update_out_and_lse(
+ out: Optional[torch.Tensor],
+ lse: Optional[torch.Tensor],
+ block_out: torch.Tensor,
+ block_lse: torch.Tensor,
+ slice_=None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ if out is None:
+ if slice_ is not None:
+ raise RuntimeError("first update_out_and_lse should not pass slice_ args")
+ out = block_out.to(torch.float32)
+ lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
+ elif slice_ is not None:
+ slice_out, slice_lse = out[slice_], lse[slice_]
+ slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
+ out[slice_], lse[slice_] = slice_out, slice_lse
+ else:
+ out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
+ return out, lse
+
+
+@torch.jit.script
+def flatten_varlen_lse(lse, cu_seqlens):
+ new_lse = []
+ for i in range(len(cu_seqlens) - 1):
+ start, end = cu_seqlens[i], cu_seqlens[i + 1]
+ new_lse.append(lse[i, :, : end - start])
+ return torch.cat(new_lse, dim=1)
+
+
+@torch.jit.script
+def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
+ num_seq = len(cu_seqlens) - 1
+ num_head = lse.shape[-2]
+ new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device)
+ for i in range(num_seq):
+ start, end = cu_seqlens[i], cu_seqlens[i + 1]
+ new_lse[i, : end - start] = lse[start:end]
+ return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()
+
+
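+# Thin wrapper around batched point-to-point ops: each send_recv call posts an
+# isend of the local tensor to the next rank in the group and an irecv from the
+# previous rank, commit() launches them with dist.batch_isend_irecv, and wait()
+# blocks until the exchange completes so the received block can be consumed.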
+class RingComm:
+ def __init__(self, process_group: dist.ProcessGroup):
+ self._process_group = process_group
+ self._ops = []
+ self.rank = dist.get_rank(self._process_group)
+ self.world_size = dist.get_world_size(self._process_group)
+ self._reqs = None
+
+ self.send_rank = (self.rank + 1) % self.world_size
+ self.recv_rank = (self.rank - 1) % self.world_size
+
+ if process_group is not None:
+ self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
+ self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
+
+ def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if recv_tensor is None:
+ res = torch.empty_like(to_send)
+ else:
+ res = recv_tensor
+
+ send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
+ recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
+ self._ops.append(send_op)
+ self._ops.append(recv_op)
+ return res
+
+ def commit(self):
+ if self._reqs is not None:
+ raise RuntimeError("commit called twice")
+ self._reqs = dist.batch_isend_irecv(self._ops)
+
+ def wait(self):
+ if self._reqs is None:
+ raise RuntimeError("wait called before commit")
+ for req in self._reqs:
+ req.wait()
+ self._reqs = None
+ self._ops = []
diff --git a/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn.py b/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a0450bfe0a907c2c588fca8442b53180c39ca4
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn.py
@@ -0,0 +1,349 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
+
+import torch
+from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
+
+from .utils import RingComm, update_out_and_lse
+
+
+def zigzag_ring_flash_attn_forward(
+ process_group,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+    assert causal, "zigzag ring is meaningless for causal=False"
+ comm = RingComm(process_group)
+
+ block_seq_len = q.shape[1] // 2
+ q1 = q[:, block_seq_len:]
+
+ out = None
+ lse = None
+ next_k, next_v = None, None
+
+ def forward(q, k, v, causal):
+ block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ return_softmax=True and dropout_p > 0,
+ )
+ return block_out, block_lse
+
+ for step in range(comm.world_size):
+ if step + 1 != comm.world_size:
+ next_k: torch.Tensor = comm.send_recv(k)
+ next_v: torch.Tensor = comm.send_recv(v)
+ comm.commit()
+
+ if step == 0:
+ block_out, block_lse = forward(q, k, v, causal=True)
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+ elif step <= comm.rank:
+ k0 = k[:, :block_seq_len]
+ v0 = v[:, :block_seq_len]
+ block_out, block_lse = forward(q, k0, v0, causal=False)
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+ else:
+ block_out, block_lse = forward(q1, k, v, causal=False)
+ out, lse = update_out_and_lse(
+ out,
+ lse,
+ block_out,
+ block_lse,
+ slice_=(slice(None), slice(block_seq_len, None)),
+ )
+
+ if step + 1 != comm.world_size:
+ comm.wait()
+ k = next_k
+ v = next_v
+
+ out = out.to(q.dtype)
+ lse = lse.squeeze(dim=-1).transpose(1, 2)
+ return out, lse
+
+
+def zigzag_ring_flash_attn_backward(
+ process_group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ assert causal == True, "zigzag ring is meaningless for causal=False"
+ kv_comm = RingComm(process_group)
+ d_kv_comm = RingComm(process_group)
+ dq, dk, dv = None, None, None
+ next_dk, next_dv = None, None
+ next_k, next_v = None, None
+ dk_comm_buffer, dv_comm_buffer = None, None
+
+ dout1 = dout.chunk(2, dim=1)[1]
+ q1 = q.chunk(2, dim=1)[1]
+ out1 = out.chunk(2, dim=1)[1]
+ softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
+ block_seq_len = q.shape[1] // 2
+
+ # Allocating these buffers repeatedly would be slow, so allocate them once and reuse them.
+ dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+ dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+ dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
+
+ def backward(dout, q, k, v, out, softmax_lse, causal):
+ seqlen_q = q.shape[1]
+ seqlen_kv = k.shape[1]
+ _flash_attn_backward(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq_buffer[:, :seqlen_q],
+ dk_buffer[:, :seqlen_kv],
+ dv_buffer[:, :seqlen_kv],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ rng_state=None,
+ )
+
+ for step in range(kv_comm.world_size):
+ if step + 1 != kv_comm.world_size:
+ next_k = kv_comm.send_recv(k)
+ next_v = kv_comm.send_recv(v)
+ kv_comm.commit()
+
+ if step == 0:
+ backward(dout, q, k, v, out, softmax_lse, causal=True)
+ dq = dq_buffer.to(torch.float32)
+ dk = dk_buffer.to(torch.float32)
+ dv = dv_buffer.to(torch.float32)
+ else:
+ if step <= kv_comm.rank:
+ k0 = k[:, :block_seq_len]
+ v0 = v[:, :block_seq_len]
+ backward(dout, q, k0, v0, out, softmax_lse, causal=False)
+ dq += dq_buffer
+ else:
+ backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
+ # always use the first half in dq_buffer.
+ dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]
+
+ d_kv_comm.wait()
+ dk_comm_buffer, dv_comm_buffer = dk, dv
+ dk, dv = next_dk, next_dv
+
+ if step <= kv_comm.rank:
+ dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
+ dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
+ else:
+ dk += dk_buffer
+ dv += dv_buffer
+
+ if step + 1 != kv_comm.world_size:
+ kv_comm.wait()
+ k = next_k
+ v = next_v
+
+ next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
+ next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
+ d_kv_comm.commit()
+
+ d_kv_comm.wait()
+
+ return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
+
+
+class ZigZagRingFlashAttnFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_softmax,
+ group,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ assert alibi_slopes is None
+ k = k.contiguous()
+ v = v.contiguous()
+ out, softmax_lse = zigzag_ring_flash_attn_forward(
+ group,
+ q,
+ k,
+ v,
+ softmax_scale=softmax_scale,
+ dropout_p=dropout_p,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=False,
+ )
+ # this should be out_padded
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
+ ctx.dropout_p = dropout_p
+ ctx.softmax_scale = softmax_scale
+ ctx.causal = causal
+ ctx.window_size = window_size
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+ ctx.group = group
+ return out if not return_softmax else (out, softmax_lse, None)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ q, k, v, out, softmax_lse = ctx.saved_tensors
+ dq, dk, dv = zigzag_ring_flash_attn_backward(
+ ctx.group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ softmax_scale=ctx.softmax_scale,
+ dropout_p=ctx.dropout_p,
+ causal=ctx.causal,
+ window_size=ctx.window_size,
+ alibi_slopes=ctx.alibi_slopes,
+ deterministic=ctx.deterministic,
+ )
+ return dq, dk, dv, None, None, None, None, None, None, None, None
+
+
+def zigzag_ring_flash_attn_qkvpacked_func(
+ qkv,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return ZigZagRingFlashAttnFunc.apply(
+ qkv[:, :, 0],
+ qkv[:, :, 1],
+ qkv[:, :, 2],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def zigzag_ring_flash_attn_kvpacked_func(
+ q,
+ kv,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return ZigZagRingFlashAttnFunc.apply(
+ q,
+ kv[:, :, 0],
+ kv[:, :, 1],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def zigzag_ring_flash_attn_func(
+ q,
+ k,
+ v,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return ZigZagRingFlashAttnFunc.apply(
+ q,
+ k,
+ v,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
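+
+
+def _example_zigzag_shard(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
+    """Illustrative sketch of the zigzag layout this file assumes.
+
+    Not part of the training path (the actual sharding happens in the caller);
+    this only documents the assumption: the global sequence (dim=1) is cut into
+    2 * world_size equal chunks and rank r keeps chunk r together with chunk
+    2 * world_size - 1 - r. Pairing an early chunk with a late chunk balances
+    the causal workload, and it is why the loops above have three cases: full
+    causal attention at step 0, attention to only the first half of the
+    incoming K/V when step <= rank, and only the local second half attending
+    when step > rank.
+    """
+    chunks = x.chunk(2 * world_size, dim=1)
+    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=1)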
diff --git a/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn_varlen.py b/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..73cc5cf39a84b8067d85ca9fc6a15f6c757a11fc
--- /dev/null
+++ b/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn_varlen.py
@@ -0,0 +1,467 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Adopted from https://github.com/zhuzilin/ring-flash-attention.
+# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
+
+import torch
+from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward
+
+from .utils import RingComm, update_out_and_lse
+
+try:
+ from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse
+except ImportError:
+ from .utils import flatten_varlen_lse, unflatten_varlen_lse
+
+
+def get_half_index(cu_seqlens, *, front: bool):
+ if len(cu_seqlens) == 2:
+ if front:
+ return slice(None, cu_seqlens[-1] // 2)
+ else:
+ return slice(cu_seqlens[-1] // 2, None)
+
+ index = torch.zeros((cu_seqlens[-1],), dtype=bool)
+ for i in range(len(cu_seqlens) - 1):
+ start, end = cu_seqlens[i], cu_seqlens[i + 1]
+ if front:
+ end = (start + end) // 2
+ else:
+ start = (start + end) // 2
+ index[start:end] = True
+ return index
+
+
+@torch.jit.script
+def get_half_lse(lse, cu_seqlens, *, front: bool):
+ new_lse = torch.empty(
+ (lse.shape[0], lse.shape[1], lse.shape[2] // 2),
+ dtype=lse.dtype,
+ device=lse.device,
+ )
+ for i in range(len(cu_seqlens) - 1):
+ seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item()
+ if front:
+ start, end = 0, seqlen // 2
+ else:
+ start, end = seqlen // 2, seqlen
+ new_lse[i, :, : seqlen // 2] = lse[i, :, start:end]
+ return new_lse
+
+
+def zigzag_ring_flash_attn_varlen_forward(
+ process_group,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ cu_seqlens,
+ max_seqlen,
+ half_index0,
+ half_index1,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ assert causal == True, "zigzag ring is meaningless for causal=False"
+ comm = RingComm(process_group)
+
+ block_seq_len = q.shape[0] // 2
+ q1 = q[half_index1]
+
+ out = None
+ lse = None
+ next_k, next_v = None, None
+ half_cu_seqlens = cu_seqlens // 2
+ half_max_seqlen = max_seqlen // 2
+
+ def forward(q, k, v, causal):
+ seqlen_q = q.shape[0]
+ seqlen_kv = k.shape[0]
+ cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens
+ max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen
+ cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens
+ max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen
+ block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
+ q,
+ k,
+ v,
+ # the first half and the second half are the same
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ dropout_p,
+ softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ return_softmax=True and dropout_p > 0,
+ block_table=None,
+ )
+ return block_out, block_lse
+
+ for step in range(comm.world_size):
+ if step + 1 != comm.world_size:
+ next_k: torch.Tensor = comm.send_recv(k)
+ next_v: torch.Tensor = comm.send_recv(v)
+ comm.commit()
+
+ if step == 0:
+ block_out, block_lse = forward(q, k, v, causal=True)
+ block_lse = flatten_varlen_lse(
+ block_lse,
+ cu_seqlens=cu_seqlens,
+ )
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+ elif step <= comm.rank:
+ k0 = k[half_index0]
+ v0 = v[half_index0]
+ block_out, block_lse = forward(q, k0, v0, causal=False)
+ block_lse = flatten_varlen_lse(
+ block_lse,
+ cu_seqlens=cu_seqlens,
+ )
+ out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+ else:
+ block_out, block_lse = forward(q1, k, v, causal=False)
+ block_lse = flatten_varlen_lse(
+ block_lse,
+ cu_seqlens=half_cu_seqlens,
+ )
+ out[half_index1], lse[half_index1] = update_out_and_lse(
+ out[half_index1], lse[half_index1], block_out, block_lse
+ )
+
+ if step + 1 != comm.world_size:
+ comm.wait()
+ k = next_k
+ v = next_v
+
+ out = out.to(q.dtype)
+ lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
+ return out, lse
+
+
+def zigzag_ring_flash_attn_varlen_backward(
+ process_group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ cu_seqlens,
+ max_seqlen,
+ half_index0,
+ half_index1,
+ softmax_scale,
+ dropout_p=0,
+ causal=True,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+):
+ assert causal == True, "zigzag ring is meaningless for causal=False"
+ kv_comm = RingComm(process_group)
+ d_kv_comm = RingComm(process_group)
+ dq, dk, dv = None, None, None
+ next_dk, next_dv = None, None
+ next_k, next_v = None, None
+ dk_comm_buffer, dv_comm_buffer = None, None
+
+ dout1 = dout[half_index1]
+ q1 = q[half_index1]
+ out1 = out[half_index1]
+ softmax_lse1 = get_half_lse(softmax_lse, cu_seqlens, front=False)
+ block_seq_len = q.shape[0] // 2
+
+ half_cu_seqlens = cu_seqlens // 2
+ half_max_seqlen = max_seqlen // 2
+
+ # Allocating these buffers repeatedly would be slow, so allocate them once and reuse them.
+ dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+ dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+ dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
+
+ def backward(dout, q, k, v, out, softmax_lse, causal):
+ seqlen_q = q.shape[0]
+ seqlen_kv = k.shape[0]
+ cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens
+ max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen
+ cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens
+ max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen
+ _flash_attn_varlen_backward(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq_buffer[:seqlen_q],
+ dk_buffer[:seqlen_kv],
+ dv_buffer[:seqlen_kv],
+ # the first half and the second half are the same
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ rng_state=None,
+ )
+
+ for step in range(kv_comm.world_size):
+ if step + 1 != kv_comm.world_size:
+ next_k = kv_comm.send_recv(k)
+ next_v = kv_comm.send_recv(v)
+ kv_comm.commit()
+
+ if step == 0:
+ backward(dout, q, k, v, out, softmax_lse, causal=True)
+ dq = dq_buffer.to(torch.float32)
+ dk = dk_buffer.to(torch.float32)
+ dv = dv_buffer.to(torch.float32)
+ else:
+ if step <= kv_comm.rank:
+ k0 = k[half_index0]
+ v0 = v[half_index0]
+ backward(dout, q, k0, v0, out, softmax_lse, causal=False)
+ dq += dq_buffer
+ else:
+ backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
+ dq[half_index1] += dq_buffer[:block_seq_len]
+
+ d_kv_comm.wait()
+ dk_comm_buffer, dv_comm_buffer = dk, dv
+ dk, dv = next_dk, next_dv
+
+ if step <= kv_comm.rank:
+ dk[half_index0] += dk_buffer[:block_seq_len]
+ dv[half_index0] += dv_buffer[:block_seq_len]
+ else:
+ dk += dk_buffer
+ dv += dv_buffer
+
+ if step + 1 != kv_comm.world_size:
+ kv_comm.wait()
+ k = next_k
+ v = next_v
+
+ next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
+ next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
+ d_kv_comm.commit()
+
+ d_kv_comm.wait()
+
+ return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
+
+
+class ZigZagRingFlashAttnVarlenFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_softmax,
+ group,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ assert alibi_slopes is None
+ k = k.contiguous()
+ v = v.contiguous()
+ half_index0 = get_half_index(cu_seqlens, front=True)
+ half_index1 = get_half_index(cu_seqlens, front=False)
+ out, softmax_lse = zigzag_ring_flash_attn_varlen_forward(
+ group,
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ half_index0,
+ half_index1,
+ softmax_scale=softmax_scale,
+ dropout_p=dropout_p,
+ causal=causal,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ deterministic=False,
+ )
+ # this should be out_padded
+ is_half_index_tensor = isinstance(half_index0, torch.Tensor)
+ ctx.is_half_index_tensor = is_half_index_tensor
+ if is_half_index_tensor:
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1)
+ else:
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens)
+ ctx.half_index0 = half_index0
+ ctx.half_index1 = half_index1
+ ctx.max_seqlen = max_seqlen
+ ctx.dropout_p = dropout_p
+ ctx.softmax_scale = softmax_scale
+ ctx.causal = causal
+ ctx.window_size = window_size
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+ ctx.group = group
+ return out if not return_softmax else (out, softmax_lse, None)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ if ctx.is_half_index_tensor:
+ (q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) = ctx.saved_tensors
+ else:
+ q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors
+ half_index0 = ctx.half_index0
+ half_index1 = ctx.half_index1
+ dq, dk, dv = zigzag_ring_flash_attn_varlen_backward(
+ ctx.group,
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ cu_seqlens,
+ ctx.max_seqlen,
+ half_index0,
+ half_index1,
+ softmax_scale=ctx.softmax_scale,
+ dropout_p=ctx.dropout_p,
+ causal=ctx.causal,
+ window_size=ctx.window_size,
+ alibi_slopes=ctx.alibi_slopes,
+ deterministic=ctx.deterministic,
+ )
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
+
+
+def zigzag_ring_flash_attn_varlen_qkvpacked_func(
+ qkv,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return ZigZagRingFlashAttnVarlenFunc.apply(
+ qkv[:, 0],
+ qkv[:, 1],
+ qkv[:, 2],
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def zigzag_ring_flash_attn_varlen_kvpacked_func(
+ q,
+ kv,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return ZigZagRingFlashAttnVarlenFunc.apply(
+ q,
+ kv[:, 0],
+ kv[:, 1],
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
+
+
+def zigzag_ring_flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ group=None,
+):
+ return ZigZagRingFlashAttnVarlenFunc.apply(
+ q,
+ k,
+ v,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ alibi_slopes,
+ deterministic,
+ return_attn_probs,
+ group,
+ )
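+
+
+# Worked example (illustrative only): with two packed sequences of lengths 8 and
+# 12 on this rank, cu_seqlens = tensor([0, 8, 20]). Then
+#   half_cu_seqlens = cu_seqlens // 2          -> tensor([0, 4, 10])
+#   get_half_index(cu_seqlens, front=True)     -> positions 0:4 and 8:14
+#   get_half_index(cu_seqlens, front=False)    -> positions 4:8 and 14:20
+# i.e. every packed sequence is split into its zigzag "front" and "back" halves,
+# which is exactly what the forward/backward above index via half_index0 and
+# half_index1.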
diff --git a/llava/train/sequence_parallel/ulysses_attn.py b/llava/train/sequence_parallel/ulysses_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c103bc1b48a0c8863376c77286acb67fdf4db91a
--- /dev/null
+++ b/llava/train/sequence_parallel/ulysses_attn.py
@@ -0,0 +1,237 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This file is modified from https://github.com/feifeibear/long-context-attention
+# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
+# This file is also partly modified from https://github.com/microsoft/DeepSpeed
+# Implementation refers to Ulysses Paper: https://arxiv.org/abs/2309.14509
+
+import copy
+from typing import Any, Tuple
+
+import deepspeed.comm as dist
+import torch
+import torch.distributed as torch_dist
+from flash_attn import flash_attn_func
+from torch import Tensor
+from torch.nn import Module
+
+from llava.train.sequence_parallel.globals import get_ulysses_seq_len, get_ulysses_sp_rank, get_ulysses_sp_size
+
+from .all_to_all import SeqAllGather, SeqAllToAll4D, SeqAllToAll5D
+
+
+class _ExpandKVFunction(torch.autograd.Function):
+ """
+ Replicate each KV head `repeat_times` times so Ulysses sequence parallelism can
+ still shard heads when the KV head count is smaller than the SP degree (GQA).
+
+ Args:
+ k, v: input key and value tensors.
+ repeat_times: number of copies to make of each head.
+ num_head_dim: index of the head dimension.
+ """
+
+ @staticmethod
+ def forward(ctx, k, v, repeat_times, num_head_dim):
+
+ kv_shape = k.shape
+ num_heads_kv = kv_shape[num_head_dim]
+
+ ctx.num_head_dim = num_head_dim
+ ctx.num_heads_kv = num_heads_kv
+
+ # construct a repeat spec indicating which dimension should be copied
+ repeat_index = [1] * k.ndim
+ repeat_index[num_head_dim] = repeat_times
+
+ # split k and v into per-head chunks
+ k_splits = torch.chunk(k, chunks=num_heads_kv, dim=num_head_dim)
+ v_splits = torch.chunk(v, chunks=num_heads_kv, dim=num_head_dim)
+ k_repeats, v_repeats = [], []
+ # for each chunk, make repeat_times copies.
+ for split in k_splits:
+ k_split_repeat = split.repeat(repeat_index)
+ k_repeats.append(k_split_repeat)
+
+ for split in v_splits:
+ v_split_repeat = split.repeat(repeat_index)
+ v_repeats.append(v_split_repeat)
+
+ return torch.cat(k_repeats, dim=num_head_dim), torch.cat(v_repeats, dim=num_head_dim)
+
+ @staticmethod
+ def backward(ctx, grad_output_k, grad_output_v):
+ """
+ For the backward pass, sum the gradients of the copied heads within each query group.
+ """
+
+ num_head_dim = ctx.num_head_dim
+ num_heads_kv = ctx.num_heads_kv
+
+ # split the gradients into per-query-group chunks.
+ grad_output_k_splits = torch.chunk(grad_output_k, chunks=num_heads_kv, dim=num_head_dim)
+ grad_output_v_splits = torch.chunk(grad_output_v, chunks=num_heads_kv, dim=num_head_dim)
+
+ grad_output_k_sums, grad_output_v_sums = [], []
+ # for each group, sum over the copied heads
+ for grad_output_k_split in grad_output_k_splits:
+ grad_output_k_sum = grad_output_k_split.sum(dim=num_head_dim, keepdim=True)
+ grad_output_k_sums.append(grad_output_k_sum)
+
+ for grad_output_v_split in grad_output_v_splits:
+ grad_output_v_sum = grad_output_v_split.sum(dim=num_head_dim, keepdim=True)
+ grad_output_v_sums.append(grad_output_v_sum)
+
+ # then we concat the split sums on the num_head_dim dimension.
+ grad_k = torch.cat(grad_output_k_sums, dim=num_head_dim)
+ grad_v = torch.cat(grad_output_v_sums, dim=num_head_dim)
+
+ return grad_k, grad_v, None, None
+
+
+expandKV = _ExpandKVFunction.apply
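+# Shape sketch (illustrative, assuming a (batch, seq, heads, head_dim) layout and
+# num_head_dim=2): expanding 4 KV heads to match a Ulysses degree of 8 uses
+# repeat_times = 8 // 4 = 2, so
+#   k: (b, s, 4, d) --expandKV(k, v, 2, 2)--> (b, s, 8, d)
+# with the copies ordered [h0, h0, h1, h1, h2, h2, h3, h3]; the backward pass
+# sums the copies of each head back into a single gradient per original head.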
+
+
+class UlyssesAttention(torch.nn.Module):
+ """Initialization.
+
+ Arguments:
+ local_attention (Module): local attention with q,k,v
+ sequence_process_group (ProcessGroup): sequence parallel process group
+ scatter_idx (int): scatter_idx for all2all comm
+ gather_idx (int): gather_idx for all2all comm
+ """
+
+ def __init__(
+ self,
+ local_attention: Module,
+ sequence_process_group: dist.ProcessGroup = None,
+ scatter_idx: int = 2,
+ gather_idx: int = 1,
+ ) -> None:
+
+ super().__init__()
+ self.local_attn = local_attention
+ self.spg = sequence_process_group
+ self.scatter_idx = scatter_idx
+ self.gather_idx = gather_idx
+ self.ulysses_degree = get_ulysses_sp_size()
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ *args: Any,
+ attention_mask=None,
+ dropout_p=0.0,
+ softmax_scale=None,
+ seqlens_in_batch=None,
+ causal=False,
+ window_size=(-1, -1),
+ alibi_slopes=None,
+ deterministic=False,
+ return_attn_probs=False,
+ ) -> Tensor:
+ """forward
+
+ Arguments:
+ query (Tensor): query input to the layer
+ key (Tensor): key input to the layer
+ value (Tensor): value input to the layer
+ args: other args
+
+ Returns:
+ * output (Tensor): context output
+ """
+ # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
+
+ # KV Replication for GQA
+ head_dim = 2
+ num_head_kv = key.shape[head_dim]
+ if self.ulysses_degree > num_head_kv:
+ assert self.ulysses_degree % num_head_kv == 0, "Ulysses requires the SP degree to be divisible by num_head_kv."
+ key, value = expandKV(key, value, self.ulysses_degree // num_head_kv, head_dim)
+
+ # scatter 2, gather 1
+ q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
+ k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
+ v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
+
+ if attention_mask is not None:
+ local_attention_mask = copy.deepcopy(attention_mask)
+ shard_seqlen = local_attention_mask.size(1)
+ ulysses_seq_len = get_ulysses_seq_len()
+ max_global_length = max(ulysses_seq_len)
+ global_attention_mask_list = []
+ for i in range(get_ulysses_sp_size()):
+ if i == get_ulysses_sp_rank():
+ global_attention_mask_list.append(
+ torch.cat(
+ [
+ local_attention_mask,
+ torch.zeros(
+ (local_attention_mask.size(0), max_global_length - shard_seqlen),
+ dtype=local_attention_mask.dtype,
+ device=local_attention_mask.device,
+ ),
+ ],
+ dim=1,
+ )
+ )
+ else:
+ global_attention_mask_list.append(
+ torch.zeros(
+ (local_attention_mask.size(0), max_global_length),
+ dtype=local_attention_mask.dtype,
+ device=local_attention_mask.device,
+ )
+ )
+
+ global_attention_mask = torch.stack(global_attention_mask_list, dim=0)
+ torch_dist.all_reduce(global_attention_mask, group=self.spg)
+ torch_dist.barrier(group=self.spg)
+ new_global_attention_mask_list = list(torch.unbind(global_attention_mask, dim=0))
+ # Unpad the global attention mask list and concatenate them
+ for i in range(len(new_global_attention_mask_list)):
+ new_global_attention_mask_list[i] = new_global_attention_mask_list[i][:, : ulysses_seq_len[i]]
+ global_attention_mask = torch.cat(new_global_attention_mask_list, dim=1)
+ context_layer = self.local_attn(
+ q,
+ k,
+ v,
+ *args,
+ attention_mask=global_attention_mask,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ seqlens_in_batch=seqlens_in_batch,
+ causal=causal,
+ )
+ else:
+ context_layer = self.local_attn(
+ q,
+ k,
+ v,
+ *args,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ if isinstance(context_layer, tuple):
+ context_layer = context_layer[0]
+
+ # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
+
+ # scatter 1, gather 2
+ output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
+
+ # output is back in the sequence-sharded layout, e.g. (bs, seq_len/N, head_cnt, head_size)
+ return output
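+
+
+# Usage sketch (illustrative; sp_group is assumed to be the sequence-parallel
+# process group created elsewhere in the training setup). UlyssesAttention wraps
+# a local attention callable such as flash_attn_func: each rank holds q, k, v of
+# shape (bs, seq_len/N, head_cnt, head_size), the first all-to-all regroups them
+# to (bs, seq_len, head_cnt/N, head_size) for the local kernel, and the second
+# all-to-all restores the sequence-sharded layout, so callers keep the
+# single-GPU interface:
+#
+#     dist_attn = UlyssesAttention(flash_attn_func, sequence_process_group=sp_group)
+#     out = dist_attn(q, k, v, causal=True)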
diff --git a/llava/train/slurm_utils.py b/llava/train/slurm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8211ec167f6b61fc351b12c2602a8807688202e2
--- /dev/null
+++ b/llava/train/slurm_utils.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+import time
+import warnings
+
+import requests
+import torch
+import transformers
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, get_last_checkpoint
+
+
+def get_rank():
+ if not torch.distributed.is_initialized():
+ return 0
+ return torch.distributed.get_rank()
+
+
+def get_local_rank():
+ if not torch.distributed.is_initialized():
+ return 0
+ num_gpus = torch.cuda.device_count()
+ return get_rank() % num_gpus
+
+
+def get_world_size():
+ if not torch.distributed.is_initialized():
+ return 1
+ return torch.distributed.get_world_size()
+
+
+class Timer:
+ def __init__(self):
+ self.start_time = None
+ self.elapsed_time = 0
+
+ def start(self):
+ self.start_time = time.time()
+
+ def reset(self):
+ self.start_time = None
+ self.elapsed_time = 0
+
+ def get_elapsed_time(self):
+ if self.start_time is not None:
+ return self.elapsed_time + (time.time() - self.start_time)
+
+
+timer = Timer()
+
+
+def set_timer():
+ timer.start()
+
+
+def rank_print(*s):
+ if not torch.distributed.is_initialized():
+ rank = 0
+ else:
+ rank = torch.distributed.get_rank()
+ current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ print(f"[{current_time}] Rank {rank}:", *s)
+
+
+class TimeoutTerminateCallback(transformers.TrainerCallback):
+ def __init__(self, total_time_limit=240, pre_terminate_time=10):
+ self.total_time_limit = total_time_limit
+ self.pre_terminate_time = pre_terminate_time
+ elapsed_time = timer.get_elapsed_time()
+ rank_print(
+ f"Timer for terminate callback has been set.\nTotal limit: {total_time_limit}min\nPre terminate time: {pre_terminate_time}min elapsed_time: {elapsed_time}s"
+ )
+
+ self.time_to_kill = (total_time_limit - pre_terminate_time) * 60
+
+ def on_step_end(self, args, state, control, model, **kwargs):
+ elapsed_time = timer.get_elapsed_time()
+
+ if elapsed_time is None:
+ # no timer has been set
+ return control
+
+ if elapsed_time > self.time_to_kill:
+ rank_print("Timeout, start to save checkpoint....")
+ control.should_save = True
+ control.should_training_stop = True
+
+ return control
+
+ def on_train_end(self, args, state, control, **kwargs):
+ if state.global_step < state.max_steps:
+ exit(124)
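+
+
+# Usage sketch (illustrative): call set_timer() once at job start, then register
+# the callback so training saves a checkpoint and stops before the SLURM
+# allocation (total_time_limit minutes) expires:
+#
+#     set_timer()
+#     trainer = transformers.Trainer(..., callbacks=[TimeoutTerminateCallback(total_time_limit=240)])
+#     trainer.train()
+#
+# Exiting with status 124 mirrors the convention of the `timeout` utility, so an
+# external launcher can detect the early stop and resubmit the job.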
diff --git a/llava/train/train.py b/llava/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..530ef608f3f0a77331af757a5680bd28c5fac640
--- /dev/null
+++ b/llava/train/train.py
@@ -0,0 +1,888 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import os
+import warnings
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Sequence
+
+import torch
+import transformers
+from torch.utils.data import Dataset
+from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, LlamaForCausalLM, set_seed
+from transformers.modeling_utils import unwrap_model
+
+import llava.data.dataset as dataset
+import llava.data.datasets_mixture as datasets_mixture
+from llava import conversation as conversation_lib
+from llava.constants import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, IGNORE_INDEX
+from llava.data import make_supervised_data_module
+from llava.mm_utils import process_image
+from llava.model import LlavaLlamaConfig, LlavaLlamaModel
+from llava.model.language_model.fp8linearqwen2 import Qwen2ForCausalLM # We need this line to register AutoConfig
+from llava.train.args import DataArguments, ModelArguments, TrainingArguments
+from llava.train.callbacks.autoresume_callback import AutoResumeCallback
+from llava.train.llava_trainer import LLaVATrainer, VILADPOTrainer
+from llava.train.sequence_parallel import set_pg_manager
+from llava.train.slurm_utils import TimeoutTerminateCallback
+from llava.train.utils import (
+ get_checkpoint_path,
+ mprint,
+ prepare_config_for_training,
+ unit_test_rope_scaling,
+)
+from llava.trl.trainer.utils import DPODataCollatorWithPadding
+
+local_rank = None
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ['TRANSFORMERS_CACHE'] = '.cache'
+os.environ['HF_HOME'] = '.cache'
+if "WANDB_PROJECT" not in os.environ:
+ os.environ["WANDB_PROJECT"] = "AF3"
+
+def get_nb_trainable_parameters(model) -> tuple[int, int]:
+ r"""
+ Returns the number of trainable parameters and the total number of parameters in the model.
+ """
+ trainable_params = 0
+ all_param = 0
+ for _, param in model.named_parameters():
+ num_params = param.numel()
+ # if using DS Zero 3 and the weights are initialized empty
+ if num_params == 0 and hasattr(param, "ds_numel"):
+ num_params = param.ds_numel
+
+ # Due to the design of 4-bit linear layers from bitsandbytes, the stored
+ # parameter count must be multiplied by 2 (and by the storage element size)
+ # to recover the actual number of parameters.
+ if param.__class__.__name__ == "Params4bit":
+ if hasattr(param, "element_size"):
+ num_bytes = param.element_size()
+ elif not hasattr(param, "quant_storage"):
+ num_bytes = 1
+ else:
+ num_bytes = param.quant_storage.itemsize
+ num_params = num_params * 2 * num_bytes
+
+ all_param += num_params
+ if param.requires_grad:
+ trainable_params += num_params
+
+ return trainable_params, all_param
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias.items():
+ if k in lora_bias_names:
+ to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model, lora_llm, lora_vt, lora_st, lora_sot):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ["mm_projector","speech_mm_projector","sound_mm_projector", "vision_resampler"]
+ assert lora_llm or lora_vt, "Not applying LoRA to any of the modules..."
+
+ if not lora_llm:
+ multimodal_keywords += ["llm"]
+ if not lora_vt:
+ multimodal_keywords += ["vision_tower"]
+ if not lora_st:
+ multimodal_keywords += ["speech_tower"]
+ if not lora_sot:
+ multimodal_keywords += ["sound_tower"]
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ if not "lm_head" in name:
+ lora_module_names.add(name)
+ # names = name.split(".")
+ # lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ # if "lm_head" in lora_module_names: # needed for 16-bit
+ # lora_module_names.remove("lm_head")
+ return list(lora_module_names)
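+# Illustrative note: because full dotted module names are collected here (rather
+# than just the leaf names, as in the commented-out lines above), the returned
+# list can be passed to peft.LoraConfig(target_modules=...) and will match only
+# the exact Linear layers selected by this function, e.g. something like
+# "llm.model.layers.0.self_attn.q_proj" (the precise prefixes depend on the
+# model definition, so treat that name as a hypothetical example).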
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+ """Collects the state dict and dump to disk."""
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir, _internal_call=True)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def make_conv(prompt, answer):
+ return [
+ {
+ "from": "human",
+ "value": prompt,
+ },
+ {
+ "from": "gpt",
+ "value": answer,
+ },
+ ]
+
+
+@dataclass
+class DPODataCollator(DPODataCollatorWithPadding):
+ tokenizer: Any = None
+
+ def collate(self, batch):
+ # first, pad everything to the same length
+ # input_ids, labels = tuple([instance[key] for instance in instances]
+ # for key in ("input_ids", "labels"))
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
+ # input_ids,
+ # batch_first=True,
+ # padding_value=self.tokenizer.pad_token_id)
+ # labels = torch.nn.utils.rnn.pad_sequence(labels,
+ # batch_first=True,
+ # padding_value=IGNORE_INDEX)
+ # input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ # labels = labels[:, :self.tokenizer.model_max_length]
+ # batch = dict(
+ # input_ids=input_ids,
+ # labels=labels,
+ # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ # )
+ padded_batch = {}
+ for k in batch[0].keys():
+ if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
+ # if "prompt" in k:
+ # to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
+ # else:
+ to_pad = [torch.LongTensor(ex[k]) for ex in batch]
+ if k.endswith("_input_ids"):
+ padding_value = self.pad_token_id
+ elif k.endswith("_labels"):
+ padding_value = self.label_pad_token_id
+ else:
+ continue
+ # elif k.endswith("_attention_mask"):
+ # padding_value = self.padding_value
+ # else:
+ # raise ValueError(f"Unexpected key in batch '{k}'")
+
+ padded_batch[k] = torch.nn.utils.rnn.pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
+ # for the prompt, flip back so padding is on left side
+ # if "prompt" in k:
+ # padded_batch[k] = padded_batch[k].flip(dims=[1])
+ else:
+ padded_batch[k] = [ex[k] for ex in batch]
+ for k in ["chosen_input_ids", "rejected_input_ids"]:
+ attn_k = k.replace("input_ids", "attention_mask")
+ padded_batch[attn_k] = padded_batch[k].ne(self.pad_token_id)
+ return padded_batch
+
+ def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str) -> Dict:
+ """Tokenize a single batch element.
+
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
+
+ We also create the labels for the chosen/rejected responses, which are of length equal to
+ the sum of the length of the prompt and the chosen/rejected response, with
+ label_pad_token_id for the prompt tokens.
+ """
+ # import pdb; pdb.set_trace()
+ batch = {}
+
+ chosen_sources = make_conv(prompt, chosen)
+ rejected_sources = make_conv(prompt, rejected)
+ chosen_data_dict = dataset.preprocess([chosen_sources], self.tokenizer, has_image=True)
+ # chosen_data_dict['attention_mask'] = chosen_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
+
+ rejected_data_dict = dataset.preprocess([rejected_sources], self.tokenizer, has_image=True)
+ # rejected_data_dict['attention_mask'] = rejected_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
+
+ chosen_data_dict = {k: v[0] for k, v in chosen_data_dict.items()}
+ rejected_data_dict = {k: v[0] for k, v in rejected_data_dict.items()}
+
+ for k, toks in {
+ "chosen": chosen_data_dict,
+ "rejected": rejected_data_dict,
+ }.items():
+ for type_key, tokens in toks.items():
+ if type_key == "token_type_ids":
+ continue
+ batch[f"{k}_{type_key}"] = tokens
+ return batch
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+ tokenized_batch = []
+ Xs, keys = [], []
+ for feature in features:
+ prompt = feature["prompt"]
+ chosen = feature["chosen"]
+ rejected = feature["rejected"]
+
+ batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
+ batch_element["images"] = feature["images"]
+ tokenized_batch.append(batch_element)
+
+ # return collated batch
+ padded_batch = self.collate(tokenized_batch)
+ return padded_batch
+
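+# Illustrative sketch of the collated output (keys only, not exact values): for
+# a batch of DPO examples the collator above returns
+#   chosen_input_ids / chosen_labels / chosen_attention_mask
+#   rejected_input_ids / rejected_labels / rejected_attention_mask
+# padded with pad_token_id (inputs) and label_pad_token_id (labels), plus any
+# non-token fields such as "images" passed through as plain lists.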
+
+import json
+
+
+def load_jsonl(save_path):
+ with open(save_path) as f:
+ data = [json.loads(line) for line in f.readlines()]
+ return data
+
+
+def load_json(path):
+ with open(path) as f:
+ data = json.load(f)
+ return data
+
+
+def load_data(data_path):
+ if "jsonl" in data_path:
+ data_list = load_jsonl(data_path)
+ else:
+ data_list = load_json(data_path)
+ return data_list
+
+
+class DPODataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_mixture: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
+ super(Dataset, self).__init__()
+ data_path = datasets_mixture.DATASETS_LEGACY[data_mixture].data_path
+ list_data_dict = load_data(data_path)
+ # if data_args.num_sample is not None:
+ # list_data_dict = list_data_dict[:data_args.num_sample]
+
+ print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+ self.image_folder = datasets_mixture.DATASETS_LEGACY[data_mixture].image_path
+
+ def __len__(self):
+ # return 20
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if "image" in sample else 0
+ length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ """
+ {
+ 'prompt': 'Is there a snowman wearing a green scarf and hat in the background?',
+ 'chosen': 'No, there is no snowman wearing a green scarf and hat in the background of the image. The image features a person ...',
+ 'rejected': 'No, there is no snowman in the background.',
+ 'image_path': '/mnt/bn/liangkeg/data/ruohongz/dpo_data/dpo_images/LRVInstruction-000000009569.jpg',
+ 'image_name': 'LRVInstruction-000000009569.jpg'
+ }
+ """
+ # sources = self.list_data_dict[i]
+ # if isinstance(i, int):
+ # sources = [sources]
+ # assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+ data_dict = copy.deepcopy(self.list_data_dict[i]) # deep copy because the dict is modified in place below
+
+ video_file = data_dict["video"] + ".mp4"
+ video_folder = self.image_folder
+ video_path = os.path.join(video_folder, video_file)
+ num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8
+ loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0
+
+ fps = None
+ frame_count = None
+
+ images, frames_loaded = dataset.LazySupervisedDataset._load_video(
+ video_path, num_video_frames, loader_fps, self.data_args, fps=fps, frame_count=frame_count
+ )
+
+ image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images])
+
+ data_dict["images"] = image_tensor
+
+ prompt = data_dict["prompt"]
+ prompt = prompt.replace("<image>", "").strip()
+ prompt = "<image>\n" * frames_loaded + prompt
+ data_dict["prompt"] = prompt
+
+ return data_dict
+
+
+def train():
+ global local_rank
+
+ parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # FIXME(zhijianl): This should be deprecated when we move to the new scripts.
+ if os.getenv("RUN_NAME") is not None:
+ training_args.run_name = os.getenv("RUN_NAME")
+ else:
+ training_args.run_name = training_args.output_dir.split("/")[-1]
+
+ if training_args.use_one_logger:
+ try:
+ from one_logger_utils.huggingface import TimeEventCallback, hook_trainer_cls
+ except ImportError as e:
+ logging.warning(
+ f"""one_logger_utils is not installed. Please install it to use one_logger.
+ Please install via `pip install --index-url=https://sc-hw-artf.nvidia.com/artifactory/api/pypi/hwinf-mlwfo-pypi/simple --upgrade one-logger-utils
+`"""
+ )
+ raise e
+ batch_size = int(os.environ.get("GLOBAL_TRAIN_BATCH_SIZE", 16))
+ app_tag = f"{training_args.run_name}_{training_args.model_max_length}_{batch_size}"
+ one_logger_callback_config = {
+ "enable_for_current_rank": os.environ.get("RANK") == "0",
+ "one_logger_async": True,
+ "one_logger_project": "vila",
+ "log_every_n_train_iterations": 10,
+ "app_tag_run_version": "0.0.0",
+ "summary_data_schema_version": "1.0.0",
+ "app_run_type": "training",
+ "app_tag": app_tag,
+ "app_tag_run_name": training_args.run_name,
+ "world_size": os.environ.get("WORLD_SIZE", -1),
+ "global_batch_size": batch_size,
+ "batch_size": batch_size,
+ "train_iterations_target": int(data_args.num_video_frames / batch_size),
+ "train_samples_target": data_args.num_video_frames,
+ "is_train_iterations_enabled": True,
+ "is_baseline_run": False,
+ "is_test_iterations_enabled": False,
+ "is_validation_iterations_enabled": False,
+ "is_save_checkpoint_enabled": True,
+ "is_log_throughput_enabled": False,
+ "micro_batch_size": os.environ.get("PER_DEVICE_TRAIN_BATCH_SIZE", 16),
+ "seq_length": training_args.model_max_length,
+ "save_checkpoint_strategy": "sync",
+ }
+ one_logger_callback_utils = TimeEventCallback(one_logger_callback_config)
+
+ local_rank = training_args.local_rank
+ compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+
+ bnb_model_from_pretrained_args.update(
+ dict(
+ device_map={"": training_args.device},
+ # load_in_4bit=training_args.bits == 4,
+ # load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_skip_modules=["lm_head"],
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
+ ),
+ )
+ )
+
+ set_seed(training_args.seed)
+
+ sp_degree = training_args.seq_parallel_size
+ ring_degree = training_args.seq_parallel_ring_size
+ if sp_degree > 1:
+ set_pg_manager(sp_degree, ring_degree, ring_type=training_args.seq_parallel_ring_type)
+ print(f"Sequence parallelism is enabled, SP = {sp_degree}")
+
+ resume_path, continue_training = get_checkpoint_path(training_args.output_dir)
+
+ if not continue_training:
+ print(f"Models has been ready under {training_args.output_dir}. Skipp training")
+ exit(0)
+
+ if resume_path:
+ resume_from_checkpoint = True
+ if training_args.lora_enable:
+ model_cls = LlavaLlamaModel
+ config = LlavaLlamaConfig.from_pretrained(model_args.model_name_or_path, resume=resume_from_checkpoint)
+ config.resume_path = model_args.model_name_or_path
+ else:
+ config = AutoConfig.from_pretrained(resume_path, trust_remote_code=True)
+ config.resume_path = resume_path
+ model_cls = eval(config.architectures[0])
+ else:
+ ## first time training
+ resume_from_checkpoint = False
+ ## llm and default multimodal model
+ # if (
+ # model_args.quantize_model in quantize_args_to_model_class.keys()
+ # ): # However, qmem should not be used currently because I haven't merged the memory reduction version into VILA
+ # from llava.model.language_model.qllava_qllama import QLlavaLlamaModel
+
+ # model_cls = QLlavaLlamaModel
+ # else:
+ assert (
+ model_args.quantize_model == "false"
+ ), f"{model_args.quantize_model} for model_args.quantize_model is not supported"
+ model_cls = LlavaLlamaModel
+ config = LlavaLlamaConfig.from_pretrained(model_args.model_name_or_path, resume=resume_from_checkpoint)
+
+ if getattr(config, "resume_path", None) is not None:
+ config.resume_path = model_args.model_name_or_path
+
+ ## extra configurations
+ prepare_config_for_training(config, model_args, training_args, data_args)
+
+ if training_args.use_one_logger:
+ one_logger_callback_utils.on_model_init_start()
+
+ # if model_args.quantize_model in quantize_args_to_model_class.keys():
+ # model = model_cls(
+ # config=config,
+ # model_args=model_args,
+ # attn_implementation="flash_attention_2",
+ # model_max_length=training_args.model_max_length,
+ # cache_dir=training_args.cache_dir,
+ # **bnb_model_from_pretrained_args,
+ # )
+ # else:
+ model = model_cls(
+ config=config,
+ attn_implementation="flash_attention_2",
+ model_max_length=training_args.model_max_length,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args,
+ )
+
+ if training_args.use_one_logger:
+ one_logger_callback_utils.on_model_init_end()
+
+ if not resume_path or training_args.lora_enable:
+ if model_args.mlp_path is not None:
+ state_dict = torch.load(model_args.mlp_path, map_location="cpu")
+ state_dict_new = {}
+ for k, v in state_dict.items():
+ if k == "0.weight":
+ state_dict_new["layers.1.weight"] = v
+ if k == "0.bias":
+ state_dict_new["layers.1.bias"] = v
+ if k == "1.weight":
+ state_dict_new["layers.2.weight"] = v
+ if k == "1.bias":
+ state_dict_new["layers.2.bias"] = v
+ if k == "3.weight":
+ state_dict_new["layers.4.weight"] = v
+ if k == "3.bias":
+ state_dict_new["layers.4.bias"] = v
+ model.get_mm_projector().load_state_dict(state_dict_new)
+
+ # This is an empty function by default.
+ # It is overridden by the unit test script.
+ if unit_test_rope_scaling(model, model.llm.config, training_args):
+ return
+
+ # Take a look on model architecture.
+ mprint(model)
+
+ model.llm.config.use_cache = False
+
+ ## set tunable parameters
+ # logging.warning(
+ # "You are setting tunable parameters for the model. Previous args include 'freeze_backbone' and 'tune_mm_mlp_adapter' are deprecated.\n Notice: default value of tune_xxx is False, which means you would not tune this part."
+ # )
+
+ def need_to_modify_do_sample(generation_config):
+ if generation_config is None:
+ warnings.warn("generation config is None, skip do sample modification")
+ return False
+ if generation_config.do_sample is False:
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ return True
+ if generation_config.top_p is not None and generation_config.top_p != 1.0:
+ return True
+ return False
+
+ if need_to_modify_do_sample(model.llm.generation_config):
+ model.llm.generation_config.do_sample = True
+
+ ## quantize training @yunhao: be careful here
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+
+ model.llm.config.torch_dtype = (
+ torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+ )
+ model.llm = prepare_model_for_kbit_training(
+ model.llm, use_gradient_checkpointing=training_args.gradient_checkpointing
+ )
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model.llm, "enable_input_require_grads"):
+ model.llm.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, PeftModel, get_peft_model
+
+ lora_config = LoraConfig(
+ use_dora=training_args.use_dora,
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model, training_args.lora_llm, training_args.lora_vt, training_args.lora_st, training_args.lora_sot),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ if resume_from_checkpoint:
+ # load non-lora weights
+ if os.path.exists(os.path.join(resume_path, "non_lora_trainables.bin")):
+ non_lora_trainables = torch.load(
+ os.path.join(resume_path, "non_lora_trainables.bin"),
+ map_location="cpu",
+ )
+ non_lora_trainables = {
+ (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
+ non_lora_trainables = {
+ (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ mprint("Resume from checkpoint...", resume_path)
+ model = PeftModel.from_pretrained(model, resume_path, is_trainable=True)
+ else:
+ mprint("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+ mprint(model)
+ model.print_trainable_parameters()
+
+ # currently assume full fine-tuning (FFT) for the mm projector
+ if training_args.lora_enable:
+ if not training_args.lora_llm:
+ model.get_llm().requires_grad_(training_args.tune_language_model)
+
+ if model.get_sound_tower():
+ if training_args.lora_sot:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_sound_tower().sound_tower.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grad
+ )
+ elif training_args.tune_sound_tower:
+ model.get_sound_tower().requires_grad_(training_args.tune_sound_tower)
+ model.get_sound_mm_projector().requires_grad_(training_args.tune_sound_mm_projector)
+ mprint(f"sound mm projector {training_args.tune_sound_mm_projector}")
+ model.print_trainable_parameters()
+ else:
+ model.get_llm().requires_grad_(training_args.tune_language_model)
+ mprint(f"Tunable parameters:\nlanguage model {training_args.tune_language_model}")
+ model.get_sound_tower().requires_grad_(training_args.tune_sound_tower)
+ model.get_sound_mm_projector().requires_grad_(training_args.tune_sound_mm_projector)
+
+ mprint(f"sound tower {training_args.tune_sound_tower}")
+ mprint(f"sound mm projector {training_args.tune_sound_mm_projector}")
+ trainable_params, all_param = get_nb_trainable_parameters(model)
+ print(
+ f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
+ )
+
+ if not any(
+ [training_args.tune_language_model, training_args.tune_vision_tower, training_args.tune_speech_tower, training_args.tune_sound_tower, training_args.tune_mm_projector, training_args.tune_speech_mm_projector, training_args.tune_sound_mm_projector]
+ ):
+ logging.warning("You are not tuning any part of the model. Please check if this is intended.")
+
+ # @yunhao: tokenizer instantiation is moved into build_llm
+ tokenizer = model.tokenizer
+
+ if tokenizer.bos_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(bos_token="[BOS]"),
+ tokenizer=tokenizer,
+ model=model.llm,
+ )
+
+ # @yunhao: may move this block into method "build_llm"
+ tokenizer.pad_token = tokenizer.unk_token
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token="[PAD]"),
+ tokenizer=tokenizer,
+ model=model.llm,
+ )
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+
+ sound_tower = model.get_sound_tower()
+ data_args.is_multimodal = True
+
+ if sound_tower is not None:
+ model.config.sound_mm_projector_lr = training_args.sound_mm_projector_lr
+ model.config.sound_tower_lr = training_args.sound_tower_lr
+
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ assert not model_args.mm_use_im_patch_token
+
+ model.config.num_time_tokens = data_args.num_time_tokens = model_args.num_time_tokens
+ model.config.time_token_format = data_args.time_token_format = model_args.time_token_format
+ if model_args.num_time_tokens > 0:
+ time_tokens = [model.config.time_token_format.format(t=t) for t in range(model.config.num_time_tokens)]
+ num_new_tokens = tokenizer.add_tokens(time_tokens)
+ assert len(time_tokens) == num_new_tokens or num_new_tokens == 0
+ model.resize_token_embeddings(len(tokenizer))
+ model.config.time_token_ids = tokenizer.convert_tokens_to_ids(time_tokens)
+ else:
+ model.config.time_token_ids = []
+ model.config.soft_ce_std = model_args.soft_ce_std
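+    # Hedged illustration (hypothetical values): with num_time_tokens=3 and
+    # time_token_format="<t{t}>", the time-token block above registers the tokens
+    # ["<t0>", "<t1>", "<t2>"], resizes the token embeddings accordingly, and
+    # records their ids in model.config.time_token_ids.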
+
+ ## TODO pay attention to quantize
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if "norm" in name:
+ module = module.to(torch.float32)
+ if "lm_head" in name or "embed_tokens" in name:
+ if hasattr(module, "weight"):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_args.s2_scales = list(map(int, model_args.s2_scales.split(",")))
+ data_args.group_by_modality_length = training_args.group_by_modality_length
+ data_module = make_supervised_data_module(
+ tokenizer=tokenizer,
+ data_args=data_args,
+ training_args=training_args,
+ )
+
+ # Add a training step_end callback to check whether to autosuspend.
+ callbacks = [AutoResumeCallback(), TimeoutTerminateCallback()]
+
+ if training_args.dpo:
+ ref_model = model_cls(
+ config=config,
+ attn_implementation="flash_attention_2",
+ model_max_length=training_args.model_max_length,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args,
+ )
+
+ train_dataset = DPODataset(tokenizer=tokenizer, data_mixture=data_args.data_mixture, data_args=data_args)
+
+ data_collator = DPODataCollator(
+ tokenizer=tokenizer,
+ label_pad_token_id=IGNORE_INDEX,
+ pad_token_id=tokenizer.pad_token_id,
+ )
+ extra_info = []
+ extra_info.append(len(train_dataset))
+ training_args.sample_lens = extra_info
+
+ trainer = VILADPOTrainer(
+ model=model,
+ dpo_alpha=1.0,
+ gamma=0,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ args=training_args,
+ beta=training_args.dpo_beta,
+ callbacks=callbacks,
+ train_dataset=train_dataset,
+ data_collator=data_collator,
+ )
+ else:
+ if training_args.use_one_logger:
+ newLLaVATrainer = hook_trainer_cls(LLaVATrainer, one_logger_callback_utils=one_logger_callback_utils)
+ trainer = newLLaVATrainer(
+ model=model, tokenizer=tokenizer, args=training_args, callbacks=callbacks, **data_module
+ )
+ else:
+ trainer = LLaVATrainer(
+ model=model, tokenizer=tokenizer, args=training_args, callbacks=callbacks, **data_module
+ )
+
+ if model_args.quantize_model in ["fp8Activation_qwen2", "fp8ActivationResidual_qwen2"]:
+ from llava.model.coat.fp8_trainer import CoatFP8Trainer
+
+ trainer._inner_training_loop = CoatFP8Trainer._inner_training_loop.__get__(
+ trainer, LLaVATrainer
+        )  # bind CoatFP8Trainer's _inner_training_loop to this trainer instance via the descriptor protocol
+
+ print(
+ "length of dataloader:",
+ len(trainer.get_train_dataloader()),
+ len(trainer.train_dataset),
+ flush=True,
+ )
+ print(
+ "[GPU memory] before trainer",
+ torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+ flush=True,
+ )
+
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+ if training_args.debug_e2e:
+ exit()
+
+ trainer.save_state()
+
+ model.llm.config.use_cache = True
+ model.config.resume_path = model.config._name_or_path = training_args.output_dir
+ ## TODO handle lora for new initialization
+ if training_args.lora_enable:
+ if training_args.use_one_logger:
+ one_logger_callback_utils.on_save_checkpoint_start(global_step=trainer.state.global_step)
+ state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(
+ non_lora_state_dict,
+ os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
+ )
+ if training_args.use_one_logger:
+ one_logger_callback_utils.on_save_checkpoint_success(global_step=trainer.state.global_step)
+ one_logger_callback_utils.on_save_checkpoint_end(global_step=trainer.state.global_step)
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+ if training_args.use_one_logger:
+ one_logger_callback_utils.on_app_end()
+
+
+if __name__ == "__main__":
+ train()
diff --git a/llava/train/train_hybrid.py b/llava/train/train_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aef321ab810085ee3c5768155d105be76342478
--- /dev/null
+++ b/llava/train/train_hybrid.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+from unittest import mock
+
+from llava.model.utils.packing import _get_unpad_data
+from llava.train.sequence_parallel.monkey_patch import _flash_attention_forward, _update_causal_mask
+from llava.train.train import train
+from llava.train.transformer_normalize_monkey_patch import patched_normalize
+
+
+def __len__(self):
+ return len(self.batch_sampler)
+
+
+def __iter__(self):
+ return self.batch_sampler.__iter__()
+
+
+if __name__ == "__main__":
+ with (
+ mock.patch("transformers.models.llama.modeling_llama._flash_attention_forward", new=_flash_attention_forward),
+ mock.patch("transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask", new=_update_causal_mask),
+ mock.patch("transformers.models.qwen2.modeling_qwen2._flash_attention_forward", new=_flash_attention_forward),
+ mock.patch("transformers.models.qwen2.modeling_qwen2.Qwen2Model._update_causal_mask", new=_update_causal_mask),
+ mock.patch("transformers.image_processing_utils.normalize", new=patched_normalize),
+ mock.patch("accelerate.data_loader.BatchSamplerShard.__len__", new=__len__),
+ mock.patch("accelerate.data_loader.BatchSamplerShard.__iter__", new=__iter__),
+ ):
+ train()
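+
+# Note (descriptive only): `mock.patch(..., new=...)` swaps each referenced
+# attribute for the duration of the `with` block, so the sequence-parallel
+# flash-attention forward, the causal-mask update, the image normalization,
+# and the BatchSamplerShard __len__/__iter__ overrides are only active while
+# `train()` runs and are restored automatically on exit.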
diff --git a/llava/train/train_ln.py b/llava/train/train_ln.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f04a0fff113468a4a956c778bcc05ebb8ad56b0
--- /dev/null
+++ b/llava/train/train_ln.py
@@ -0,0 +1,870 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import os
+import warnings
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Sequence
+
+import torch
+import transformers
+from torch.utils.data import Dataset
+from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, LlamaForCausalLM, set_seed
+from transformers.modeling_utils import unwrap_model
+
+import llava.data.dataset as dataset
+import llava.data.datasets_mixture as datasets_mixture
+from llava import conversation as conversation_lib
+from llava.constants import (
+ DEFAULT_IM_END_TOKEN,
+ DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_TOKEN,
+ IGNORE_INDEX,
+ IMAGE_TOKEN_INDEX,
+)
+from llava.data import make_supervised_data_module
+from llava.mm_utils import process_image
+from llava.model import LlavaLlamaConfig, LlavaLlamaModel
+from llava.train.args import DataArguments, ModelArguments, TrainingArguments
+from llava.train.callbacks.autoresume_callback import AutoResumeCallback
+from llava.train.llava_trainer import LLaVATrainer, VILADPOTrainer
+from llava.train.sequence_parallel import set_pg_manager
+from llava.train.slurm_utils import TimeoutTerminateCallback
+from llava.train.utils import (
+ get_checkpoint_path,
+ mprint,
+ prepare_config_for_training,
+ unit_test_rope_scaling,
+ vision_resolution_elevation,
+)
+from llava.trl.trainer.utils import DPODataCollatorWithPadding
+
+local_rank = None
+
+if "WANDB_PROJECT" not in os.environ:
+ os.environ["WANDB_PROJECT"] = "AF3"
+
+
+def get_nb_trainable_parameters(model) -> tuple[int, int]:
+ r"""
+ Returns the number of trainable parameters and the number of all parameters in the model.
+ """
+ trainable_params = 0
+ all_param = 0
+ for _, param in model.named_parameters():
+ num_params = param.numel()
+ # if using DS Zero 3 and the weights are initialized empty
+ if num_params == 0 and hasattr(param, "ds_numel"):
+ num_params = param.ds_numel
+
+ # Due to the design of 4bit linear layers from bitsandbytes
+ # one needs to multiply the number of parameters by 2 to get
+ # the correct number of parameters
+ if param.__class__.__name__ == "Params4bit":
+ if hasattr(param, "element_size"):
+ num_bytes = param.element_size()
+ elif not hasattr(param, "quant_storage"):
+ num_bytes = 1
+ else:
+ num_bytes = param.quant_storage.itemsize
+ num_params = num_params * 2 * num_bytes
+
+ all_param += num_params
+ if param.requires_grad:
+ trainable_params += num_params
+
+ return trainable_params, all_param
+
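+
+# Hedged illustration (not used by the training flow): a tiny self-contained
+# sketch of what `get_nb_trainable_parameters` reports when part of a module is
+# frozen, as a LoRA-style setup would freeze the backbone. The toy module is
+# hypothetical and exists only for this example.
+def _example_count_trainable_params():
+    from torch import nn
+
+    toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
+    toy[0].requires_grad_(False)  # freeze the first layer
+    trainable, total = get_nb_trainable_parameters(toy)
+    print(f"trainable params: {trainable:,d} || all params: {total:,d} || trainable%: {100 * trainable / total:.4f}")
+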
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+        for k, t in maybe_lora_bias.items():
+            if k in lora_bias_names:
+                to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model, lora_llm, lora_vt):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ["mm_projector", "vision_resampler"]
+ assert lora_llm or lora_vt, "Not applying LoRA to any of the modules..."
+
+ if not lora_llm:
+ multimodal_keywords += ["llm"]
+ if not lora_vt:
+ multimodal_keywords += ["vision_tower"]
+
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+            if "lm_head" not in name:
+ lora_module_names.add(name)
+ # names = name.split(".")
+ # lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ # if "lm_head" in lora_module_names: # needed for 16-bit
+ # lora_module_names.remove("lm_head")
+ return list(lora_module_names)
+
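+
+# Hedged illustration on a toy module (not a VILA model): only nn.Linear layers
+# whose names avoid the excluded keywords are returned, ready to be passed as
+# LoraConfig(target_modules=...). The toy class and its attribute names are
+# hypothetical and exist only for this sketch.
+def _example_find_linear_names():
+    from torch import nn
+
+    class _Toy(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.llm = nn.Linear(8, 8)  # kept: plain Linear, not excluded
+            self.mm_projector = nn.Linear(8, 8)  # skipped: matches a multimodal keyword
+            self.lm_head = nn.Linear(8, 8)  # skipped: lm_head is excluded explicitly
+
+    return find_all_linear_names(_Toy(), lora_llm=True, lora_vt=True)  # -> ["llm"]
+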
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+ """Collects the state dict and dump to disk."""
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir, _internal_call=True)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
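+
+# Hedged sketch (illustration only) of the mean-initialization trick above,
+# applied to a bare tensor instead of a real model/tokenizer pair, which would
+# be too heavy to construct here. The sizes are made up.
+def _example_mean_init_new_rows(num_new_tokens: int = 2):
+    embeddings = torch.randn(10, 4)  # stand-in for the resized embedding matrix
+    if num_new_tokens > 0:
+        avg = embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+        embeddings[-num_new_tokens:] = avg  # new rows start at the mean of the old rows
+    return embeddings
+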
+
+def make_conv(prompt, answer):
+ return [
+ {
+ "from": "human",
+ "value": prompt,
+ },
+ {
+ "from": "gpt",
+ "value": answer,
+ },
+ ]
+
+
+@dataclass
+class DPODataCollator(DPODataCollatorWithPadding):
+ tokenizer: Any = None
+
+ def collate(self, batch):
+ # first, pad everything to the same length
+ # input_ids, labels = tuple([instance[key] for instance in instances]
+ # for key in ("input_ids", "labels"))
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
+ # input_ids,
+ # batch_first=True,
+ # padding_value=self.tokenizer.pad_token_id)
+ # labels = torch.nn.utils.rnn.pad_sequence(labels,
+ # batch_first=True,
+ # padding_value=IGNORE_INDEX)
+ # input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ # labels = labels[:, :self.tokenizer.model_max_length]
+ # batch = dict(
+ # input_ids=input_ids,
+ # labels=labels,
+ # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ # )
+ padded_batch = {}
+ for k in batch[0].keys():
+ if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
+ # if "prompt" in k:
+ # to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
+ # else:
+ to_pad = [torch.LongTensor(ex[k]) for ex in batch]
+ if k.endswith("_input_ids"):
+ padding_value = self.pad_token_id
+ elif k.endswith("_labels"):
+ padding_value = self.label_pad_token_id
+ else:
+ continue
+ # elif k.endswith("_attention_mask"):
+ # padding_value = self.padding_value
+ # else:
+ # raise ValueError(f"Unexpected key in batch '{k}'")
+
+ padded_batch[k] = torch.nn.utils.rnn.pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
+ # for the prompt, flip back so padding is on left side
+ # if "prompt" in k:
+ # padded_batch[k] = padded_batch[k].flip(dims=[1])
+ else:
+ padded_batch[k] = [ex[k] for ex in batch]
+ for k in ["chosen_input_ids", "rejected_input_ids"]:
+ attn_k = k.replace("input_ids", "attention_mask")
+ padded_batch[attn_k] = padded_batch[k].ne(self.pad_token_id)
+ return padded_batch
+
+ def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str) -> Dict:
+ """Tokenize a single batch element.
+
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
+
+ We also create the labels for the chosen/rejected responses, which are of length equal to
+ the sum of the length of the prompt and the chosen/rejected response, with
+ label_pad_token_id for the prompt tokens.
+ """
+ # import pdb; pdb.set_trace()
+ batch = {}
+
+ chosen_sources = make_conv(prompt, chosen)
+ rejected_sources = make_conv(prompt, rejected)
+ chosen_data_dict = dataset.preprocess([chosen_sources], self.tokenizer, has_image=True)
+ # chosen_data_dict['attention_mask'] = chosen_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
+
+ rejected_data_dict = dataset.preprocess([rejected_sources], self.tokenizer, has_image=True)
+ # rejected_data_dict['attention_mask'] = rejected_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
+
+ chosen_data_dict = {k: v[0] for k, v in chosen_data_dict.items()}
+ rejected_data_dict = {k: v[0] for k, v in rejected_data_dict.items()}
+
+ for k, toks in {
+ "chosen": chosen_data_dict,
+ "rejected": rejected_data_dict,
+ }.items():
+ for type_key, tokens in toks.items():
+ if type_key == "token_type_ids":
+ continue
+ batch[f"{k}_{type_key}"] = tokens
+ return batch
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+ tokenized_batch = []
+ Xs, keys = [], []
+ for feature in features:
+ prompt = feature["prompt"]
+ chosen = feature["chosen"]
+ rejected = feature["rejected"]
+
+ batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
+ batch_element["images"] = feature["images"]
+ tokenized_batch.append(batch_element)
+
+ # return collated batch
+ padded_batch = self.collate(tokenized_batch)
+ return padded_batch
+
+
+import json
+
+
+def load_jsonl(save_path):
+ with open(save_path) as f:
+ data = [json.loads(line) for line in f.readlines()]
+ return data
+
+
+def load_json(path):
+ with open(path) as f:
+ data = json.load(f)
+ return data
+
+
+def load_data(data_path):
+ if "jsonl" in data_path:
+ data_list = load_jsonl(data_path)
+ else:
+ data_list = load_json(data_path)
+ return data_list
+
+
+class DPODataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_mixture: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
+        super().__init__()
+ data_path = datasets_mixture.DATASETS_LEGACY[data_mixture].data_path
+ list_data_dict = load_data(data_path)
+ # if data_args.num_sample is not None:
+ # list_data_dict = list_data_dict[:data_args.num_sample]
+
+ print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+ self.image_folder = datasets_mixture.DATASETS_LEGACY[data_mixture].image_path
+
+ def __len__(self):
+ # return 20
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if "image" in sample else 0
+ length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ """
+ {
+ 'prompt': 'Is there a snowman wearing a green scarf and hat in the background?',
+ 'chosen': 'No, there is no snowman wearing a green scarf and hat in the background of the image. The image features a person ...',
+ 'rejected': 'No, there is no snowman in the background.',
+ 'image_path': '/mnt/bn/liangkeg/data/ruohongz/dpo_data/dpo_images/LRVInstruction-000000009569.jpg',
+ 'image_name': 'LRVInstruction-000000009569.jpg'
+ }
+ """
+ # sources = self.list_data_dict[i]
+ # if isinstance(i, int):
+ # sources = [sources]
+ # assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+        data_dict = copy.deepcopy(self.list_data_dict[i])  # deep copy because the dict is modified in place below
+
+ video_file = data_dict["video"] + ".mp4"
+ video_folder = self.image_folder
+ video_path = os.path.join(video_folder, video_file)
+ num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8
+ loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0
+
+ fps = None
+ frame_count = None
+
+ images, frames_loaded = dataset.LazySupervisedDataset._load_video(
+ video_path, num_video_frames, loader_fps, self.data_args, fps=fps, frame_count=frame_count
+ )
+
+        image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images])
+
+ data_dict["images"] = image_tensor
+
+ prompt = data_dict["prompt"]
+        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "").strip()
+        prompt = (DEFAULT_IMAGE_TOKEN + "\n") * frames_loaded + prompt
+ data_dict["prompt"] = prompt
+
+ return data_dict
+
+
+def train():
+ global local_rank
+
+ parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # FIXME(zhijianl): This should be deprecated when we move to the new scripts.
+ if os.getenv("RUN_NAME") is not None:
+ training_args.run_name = os.getenv("RUN_NAME")
+ else:
+ training_args.run_name = training_args.output_dir.split("/")[-1]
+
+ local_rank = training_args.local_rank
+ compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+
+ bnb_model_from_pretrained_args.update(
+ dict(
+ device_map={"": training_args.device},
+ # load_in_4bit=training_args.bits == 4,
+ # load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_skip_modules=["lm_head"],
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
+ ),
+ )
+ )
+
+ set_seed(training_args.seed)
+
+ sp_degree = training_args.seq_parallel_size
+ ring_degree = training_args.seq_parallel_ring_size
+ if sp_degree > 1:
+ set_pg_manager(sp_degree, ring_degree, ring_type=training_args.seq_parallel_ring_type)
+ print(f"Sequence parallelism is enabled, SP = {sp_degree}")
+
+ resume_path, continue_training = get_checkpoint_path(training_args.output_dir)
+
+ if not continue_training:
+        print(f"A trained model already exists under {training_args.output_dir}. Skipping training.")
+ exit(0)
+
+ if resume_path:
+ resume_from_checkpoint = True
+ if training_args.lora_enable:
+ model_cls = LlavaLlamaModel
+ config = LlavaLlamaConfig.from_pretrained(model_args.model_name_or_path, resume=resume_from_checkpoint)
+ config.resume_path = model_args.model_name_or_path
+ else:
+ config = AutoConfig.from_pretrained(resume_path, trust_remote_code=True)
+ config.resume_path = resume_path
+ model_cls = eval(config.architectures[0])
+ else:
+ ## first time training
+ resume_from_checkpoint = False
+ ## llm and default multimodal model
+ if model_args.quantize_model.lower() in [
+ "qlinear",
+ "te_qlinear",
+ "qmem",
+        ]:  # However, qmem should not be used currently because the memory-reduction version has not been merged into VILA
+ from functools import partial
+
+ from llava.model.language_model.qllava_qllama import QLlavaLlamaModel
+
+ model_cls = QLlavaLlamaModel
+ else:
+ assert (
+ model_args.quantize_model.lower() == "false"
+ ), f"{model_args.quantize_model.lower()} for model_args.quantize_model is not supported"
+ model_cls = LlavaLlamaModel
+ config = LlavaLlamaConfig.from_pretrained(model_args.model_name_or_path, resume=resume_from_checkpoint)
+
+ if getattr(config, "resume_path", None) is not None:
+ config.resume_path = model_args.model_name_or_path
+
+ ## extra configurations
+ prepare_config_for_training(config, model_args, training_args, data_args)
+ if model_args.quantize_model.lower() in ["qlinear", "te_qlinear", "qmem"]:
+ model = model_cls(
+ config=config,
+ model_args=model_args,
+ attn_implementation="flash_attention_2",
+ model_max_length=training_args.model_max_length,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args,
+ )
+ else:
+ model = model_cls(
+ config=config,
+ attn_implementation="flash_attention_2",
+ model_max_length=training_args.model_max_length,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args,
+ )
+
+ if not resume_path or training_args.lora_enable:
+ if model_args.mlp_path is not None:
+ state_dict = torch.load(model_args.mlp_path, map_location="cpu")
+ state_dict_new = {}
+ for k, v in state_dict.items():
+ if k == "0.weight":
+ state_dict_new["layers.1.weight"] = v
+ if k == "0.bias":
+ state_dict_new["layers.1.bias"] = v
+ if k == "1.weight":
+ state_dict_new["layers.2.weight"] = v
+ if k == "1.bias":
+ state_dict_new["layers.2.bias"] = v
+ if k == "3.weight":
+ state_dict_new["layers.4.weight"] = v
+ if k == "3.bias":
+ state_dict_new["layers.4.bias"] = v
+ model.get_mm_projector().load_state_dict(state_dict_new)
+
+ vision_resolution_elevation(model, config)
+ # This is an empty func.
+ # It would be overwritten by unit test script.
+ if unit_test_rope_scaling(model, model.llm.config, training_args):
+ return
+
+ # Take a look on model architecture.
+ mprint(model)
+
+ model.llm.config.use_cache = False
+
+    ## set tunable parameters
+ logging.warning(
+        "You are setting tunable parameters for the model. Previous args such as 'freeze_backbone' and 'tune_mm_mlp_adapter' are deprecated.\n Notice: the default value of each tune_xxx flag is False, which means that part will not be tuned."
+ )
+
+ def need_to_modify_do_sample(generation_config):
+ if generation_config is None:
+ warnings.warn("generation config is None, skip do sample modification")
+ return False
+ if generation_config.do_sample is False:
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ return True
+ if generation_config.top_p is not None and generation_config.top_p != 1.0:
+ return True
+ return False
+
+ if need_to_modify_do_sample(model.llm.generation_config):
+ model.llm.generation_config.do_sample = True
+
+ ## quantize training @yunhao: be careful here
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+
+ model.llm.config.torch_dtype = (
+ torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+ )
+ model.llm = prepare_model_for_kbit_training(
+ model.llm, use_gradient_checkpointing=training_args.gradient_checkpointing
+ )
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model.llm, "enable_input_require_grads"):
+ model.llm.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, PeftModel, get_peft_model
+
+ lora_config = LoraConfig(
+ use_dora=training_args.use_dora,
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model, training_args.lora_llm, training_args.lora_vt),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ if resume_from_checkpoint:
+ # load non-lora weights
+ if os.path.exists(os.path.join(resume_path, "non_lora_trainables.bin")):
+ non_lora_trainables = torch.load(
+ os.path.join(resume_path, "non_lora_trainables.bin"),
+ map_location="cpu",
+ )
+ non_lora_trainables = {
+ (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
+ non_lora_trainables = {
+ (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ mprint("Resume from checkpoint...", resume_path)
+ model = PeftModel.from_pretrained(model, resume_path, is_trainable=True)
+ else:
+ mprint("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+ mprint(model)
+ model.print_trainable_parameters()
+
+    # currently assume full fine-tuning (FFT) for the mm projector
+ if training_args.lora_enable:
+ if not training_args.lora_llm:
+ model.get_llm().requires_grad_(training_args.tune_language_model)
+ if model.get_vision_tower():
+ if training_args.lora_vt:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_vision_tower().vision_tower.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grad
+ )
+ elif training_args.tune_vision_tower:
+ model.get_vision_tower().requires_grad_(training_args.tune_vision_tower)
+ close_modules = ["embedding", "mlp", "self_attn", "head"]
+ for name, param in model.get_vision_tower().named_parameters():
+ if any(f"{module}" in name for module in close_modules):
+ print(f"freeze {name}")
+ param.requires_grad = False
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ if "embedding" in close_modules:
+ model.get_vision_tower().vision_tower.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grad
+ )
+ model.get_mm_projector().requires_grad_(training_args.tune_mm_projector)
+ mprint(f"mm projector {training_args.tune_mm_projector}")
+ model.print_trainable_parameters()
+ else:
+ model.get_llm().requires_grad_(training_args.tune_language_model)
+ mprint(f"Tunable parameters:\nlanguage model {training_args.tune_language_model}")
+ if model.get_vision_tower():
+ model.get_vision_tower().requires_grad_(training_args.tune_vision_tower)
+ close_modules = ["embedding", "mlp", "self_attn", "head"]
+ for name, param in model.get_vision_tower().named_parameters():
+ if any(f"{module}" in name for module in close_modules):
+ print(f"freeze {name}")
+ param.requires_grad = False
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ if "embedding" in close_modules:
+ model.get_vision_tower().vision_tower.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grad
+ )
+ model.get_mm_projector().requires_grad_(training_args.tune_mm_projector)
+ mprint(f"vision tower {training_args.tune_vision_tower}")
+ mprint(f"mm projector {training_args.tune_mm_projector}")
+ trainable_params, all_param = get_nb_trainable_parameters(model)
+ print(
+ f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
+ )
+
+ if not any(
+ [training_args.tune_language_model, training_args.tune_vision_tower, training_args.tune_mm_projector]
+ ):
+ logging.warning("You are not tuning any part of the model. Please check if this is intended.")
+
+ # @yunhao: tokenizer instantiation is moved into build_llm
+ tokenizer = model.tokenizer
+
+ if tokenizer.bos_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(bos_token="[BOS]"),
+ tokenizer=tokenizer,
+ model=model.llm,
+ )
+
+ # @yunhao: may move this block into method "build_llm"
+ tokenizer.pad_token = tokenizer.unk_token
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token="[PAD]"),
+ tokenizer=tokenizer,
+ model=model.llm,
+ )
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+ # kentang-mit@: It will be useful in on-the-fly packing
+ model.llm.pad_token_id = tokenizer.pad_token_id
+ model.llm.config.tokenizer_padding_side = tokenizer.padding_side
+ model.llm.config.tokenizer_model_max_length = tokenizer.model_max_length
+ if training_args.lora_enable:
+ model.base_model.model.llm.pad_token_id = tokenizer.pad_token_id
+
+ vision_tower = model.get_vision_tower()
+ if vision_tower is not None:
+ data_args.image_processor = vision_tower.image_processor
+ data_args.is_multimodal = True
+
+    if hasattr(data_args, "num_video_frames") and data_args.num_video_frames is not None:
+ model.config.num_video_frames = data_args.num_video_frames
+ else:
+ model.config.num_video_frames = 8
+
+ if hasattr(data_args, "fps"):
+ model.config.fps = data_args.fps
+ else:
+ model.config.fps = 0.0
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_projector_lr = training_args.mm_projector_lr
+ model.config.vision_tower_lr = training_args.vision_tower_lr
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ assert not model_args.mm_use_im_patch_token
+
+ model.config.num_time_tokens = data_args.num_time_tokens = model_args.num_time_tokens
+ model.config.time_token_format = data_args.time_token_format = model_args.time_token_format
+ if model_args.num_time_tokens > 0:
+ time_tokens = [model.config.time_token_format.format(t=t) for t in range(model.config.num_time_tokens)]
+ num_new_tokens = tokenizer.add_tokens(time_tokens)
+ assert len(time_tokens) == num_new_tokens or num_new_tokens == 0
+ model.resize_token_embeddings(len(tokenizer))
+ model.config.time_token_ids = tokenizer.convert_tokens_to_ids(time_tokens)
+ else:
+ model.config.time_token_ids = []
+ model.config.soft_ce_std = model_args.soft_ce_std
+
+ num_patches = model.get_vision_tower().num_patches
+ downsample_rate = model.get_mm_projector().downsample_rate
+ num_image_tokens = math.ceil(num_patches**0.5 / downsample_rate) ** 2
+ data_args.num_image_tokens = num_image_tokens
+
+ ## TODO pay attention to quantize
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if "norm" in name:
+ module = module.to(torch.float32)
+ if "lm_head" in name or "embed_tokens" in name:
+ if hasattr(module, "weight"):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_module = make_supervised_data_module(
+ tokenizer=tokenizer,
+ data_args=data_args,
+ training_args=training_args,
+ )
+
+ # Add a training step_end callback to check whether to autosuspend.
+ callbacks = [AutoResumeCallback(), TimeoutTerminateCallback()]
+
+ if training_args.dpo:
+ ref_model = model_cls(
+ config=config,
+ attn_implementation="flash_attention_2",
+ model_max_length=training_args.model_max_length,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args,
+ )
+
+ train_dataset = DPODataset(tokenizer=tokenizer, data_mixture=data_args.data_mixture, data_args=data_args)
+
+ data_collator = DPODataCollator(
+ tokenizer=tokenizer,
+ label_pad_token_id=IGNORE_INDEX,
+ pad_token_id=tokenizer.pad_token_id,
+ )
+ extra_info = []
+ extra_info.append(len(train_dataset))
+ training_args.sample_lens = extra_info
+
+ trainer = VILADPOTrainer(
+ model=model,
+ dpo_alpha=1.0,
+ gamma=0,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ args=training_args,
+ beta=training_args.dpo_beta,
+ callbacks=callbacks,
+ train_dataset=train_dataset,
+ data_collator=data_collator,
+ )
+ else:
+ trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, callbacks=callbacks, **data_module)
+ print(
+ "length of dataloader:",
+ len(trainer.get_train_dataloader()),
+ len(trainer.train_dataset),
+ flush=True,
+ )
+ print(
+ "[GPU memory] before trainer",
+ torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+ flush=True,
+ )
+
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+ if training_args.debug_e2e:
+ exit()
+
+ trainer.save_state()
+
+ model.llm.config.use_cache = True
+ model.config.resume_path = model.config._name_or_path = training_args.output_dir
+ ## TODO handle lora for new initialization
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(
+ non_lora_state_dict,
+ os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
+ )
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/llava/train/train_mem.py b/llava/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..c97c1cae635f34d863d3fe5afc23d360ef6acf4b
--- /dev/null
+++ b/llava/train/train_mem.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+from unittest import mock
+
+from llava.train.slurm_utils import set_timer
+from llava.train.train import train
+from llava.train.transformer_normalize_monkey_patch import (
+ _save_checkpoint,
+ compute_loss,
+ patched_normalize,
+ training_step,
+)
+import os
+
+os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
+
+
+def __len__(self):
+ return len(self.batch_sampler)
+
+
+def __iter__(self):
+ return self.batch_sampler.__iter__()
+
+
+if __name__ == "__main__":
+ with (
+ mock.patch("transformers.image_processing_utils.normalize", new=patched_normalize),
+ mock.patch("accelerate.data_loader.BatchSamplerShard.__len__", new=__len__),
+ mock.patch("accelerate.data_loader.BatchSamplerShard.__iter__", new=__iter__),
+ mock.patch("transformers.trainer.Trainer._save_checkpoint", new=_save_checkpoint),
+ mock.patch("transformers.trainer.Trainer.compute_loss", new=compute_loss),
+ mock.patch("transformers.trainer.Trainer.training_step", new=training_step),
+ ):
+ set_timer()
+ train()
diff --git a/llava/train/train_mem_ln.py b/llava/train/train_mem_ln.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa35701559be8cd17f11290570cda68c7383d37
--- /dev/null
+++ b/llava/train/train_mem_ln.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+from unittest import mock
+
+from llava.train.slurm_utils import set_timer
+from llava.train.train_ln import train
+from llava.train.transformer_normalize_monkey_patch import (
+ _save_checkpoint,
+ compute_loss,
+ patched_normalize,
+ training_step,
+)
+
+
+def __len__(self):
+ return len(self.batch_sampler)
+
+
+def __iter__(self):
+ return self.batch_sampler.__iter__()
+
+
+if __name__ == "__main__":
+ with (
+ mock.patch("transformers.image_processing_utils.normalize", new=patched_normalize),
+ mock.patch("accelerate.data_loader.BatchSamplerShard.__len__", new=__len__),
+ mock.patch("accelerate.data_loader.BatchSamplerShard.__iter__", new=__iter__),
+ mock.patch("transformers.trainer.Trainer._save_checkpoint", new=_save_checkpoint),
+ mock.patch("transformers.trainer.Trainer.compute_loss", new=compute_loss),
+ mock.patch("transformers.trainer.Trainer.training_step", new=training_step),
+ ):
+ set_timer()
+ train()
diff --git a/llava/train/transformer_normalize_monkey_patch.py b/llava/train/transformer_normalize_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..8afea9d06c62984f82e0a6b3aa63563e2c8cb453
--- /dev/null
+++ b/llava/train/transformer_normalize_monkey_patch.py
@@ -0,0 +1,303 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import transformers
+from transformers.image_transforms import (
+ ChannelDimension,
+ Iterable,
+ Optional,
+ Union,
+ get_channel_dimension_axis,
+ infer_channel_dimension_format,
+ np,
+ to_channel_dimension_format,
+)
+
+
+def patched_normalize(
+ image: np.ndarray,
+ mean: Union[float, Iterable[float]],
+ std: Union[float, Iterable[float]],
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ """
+ Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
+
+ image = (image - mean) / std
+
+ Args:
+ image (`np.ndarray`):
+ The image to normalize.
+ mean (`float` or `Iterable[float]`):
+ The mean to use for normalization.
+ std (`float` or `Iterable[float]`):
+ The standard deviation to use for normalization.
+ data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the output image. If unset, will use the inferred format from the input.
+ """
+ if not isinstance(image, np.ndarray):
+ raise ValueError("image must be a numpy array")
+
+ input_data_format = infer_channel_dimension_format(image)
+ channel_axis = get_channel_dimension_axis(image)
+ num_channels = image.shape[channel_axis]
+
+ if isinstance(mean, Iterable):
+ if len(mean) != num_channels:
+ if num_channels == 1:
+ num_channels = 3
+ image = np.concatenate([image, image, image], axis=channel_axis)
+ else:
+ raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
+ else:
+ mean = [mean] * num_channels
+ mean = np.array(mean, dtype=image.dtype)
+
+ if isinstance(std, Iterable):
+ if len(std) != num_channels:
+ raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
+ else:
+ std = [std] * num_channels
+ std = np.array(std, dtype=image.dtype)
+
+ if input_data_format == ChannelDimension.LAST:
+ image = (image - mean) / std
+ else:
+ image = ((image.T - mean) / std).T
+
+ image = to_channel_dimension_format(image, data_format) if data_format is not None else image
+ return image
+
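+
+# Hedged usage sketch (illustration only): normalize a dummy channels-last image
+# with per-channel statistics. The arrays and statistics below are made up.
+def _example_patched_normalize():
+    dummy = np.full((4, 4, 3), 0.5, dtype=np.float32)  # HWC image, 3 channels
+    out = patched_normalize(dummy, mean=[0.48, 0.46, 0.41], std=[0.27, 0.26, 0.28])
+    return out  # each channel is (0.5 - mean[c]) / std[c]
+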
+
+def patch_normalize_preprocess():
+ transformers.image_transforms.normalize = patched_normalize
+
+
+import os
+
+import torch
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.utils import logging
+
+TRAINER_STATE_NAME = "trainer_state.json"
+logger = logging.get_logger(__name__)
+
+
+def _save_checkpoint(self, model, trial, metrics=None):
+ # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+ # want to save except FullyShardedDDP.
+ # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
+
+ # Save model checkpoint
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ if self.hp_search_backend is None and trial is None:
+ self.store_flos()
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+
+ if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
+ logger.warning(
+            f"Checkpoint destination directory {output_dir} already exists and is non-empty. "
+            "Saving will proceed but saved results may be invalid."
+ )
+ staging_output_dir = output_dir
+ else:
+ staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
+
+ self.save_model(staging_output_dir, _internal_call=True)
+
+ if not self.args.save_only_model:
+ # Save optimizer and scheduler
+ self._save_optimizer_and_scheduler(staging_output_dir)
+ # Save RNG state
+ self._save_rng_state(staging_output_dir)
+
+ # Determine the new best metric / best model checkpoint
+ if metrics is not None and self.args.metric_for_best_model is not None:
+ metric_to_check = self.args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ metric_value = metrics[metric_to_check]
+
+ operator = np.greater if self.args.greater_is_better else np.less
+ if (
+ self.state.best_metric is None
+ or self.state.best_model_checkpoint is None
+ or operator(metric_value, self.state.best_metric)
+ ):
+ self.state.best_metric = metric_value
+ self.state.best_model_checkpoint = staging_output_dir
+
+ # Save the Trainer state
+ if self.args.should_save:
+ self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME))
+
+ if self.args.push_to_hub:
+ self._push_from_checkpoint(staging_output_dir)
+
+ torch.distributed.barrier()
+ if staging_output_dir != output_dir:
+ with self.args.main_process_first(
+ desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
+ ):
+ if os.path.exists(staging_output_dir):
+ os.rename(staging_output_dir, output_dir)
+
+ # Maybe delete some older checkpoints.
+ if self.args.should_save:
+ # Solely rely on numerical checkpoint id for rotation.
+ # mtime is not reliable especially on some fuse fs in cloud environments.
+ self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
+
+
+from typing import Any, Dict, Union
+
+from torch import nn
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.trainer import _is_peft_model  # needed by the label-smoothing branch of compute_loss below
+from transformers.training_args import OptimizerNames
+from transformers.utils import (
+ is_sagemaker_mp_enabled,
+ is_torch_mlu_available,
+ is_torch_mps_available,
+ is_torch_musa_available,
+ is_torch_npu_available,
+ is_torch_xpu_available,
+)
+
+
+def training_step(
+ self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+) -> torch.Tensor:
+ """
+ Perform a training step on a batch of inputs.
+ Subclass and override to inject custom behavior.
+ Args:
+ model (`nn.Module`):
+ The model to train.
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+ Return:
+ `torch.Tensor`: The tensor with training loss on this batch.
+ """
+ model.train()
+ if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+ self.optimizer.train()
+
+ inputs = self._prepare_inputs(inputs)
+ if is_sagemaker_mp_enabled():
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+ return loss_mb.reduce_mean().detach().to(self.args.device)
+
+ with self.compute_loss_context_manager():
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+ del inputs
+ if (
+ self.args.torch_empty_cache_steps is not None
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
+ ):
+ if is_torch_xpu_available():
+ torch.xpu.empty_cache()
+ elif is_torch_mlu_available():
+ torch.mlu.empty_cache()
+ elif is_torch_musa_available():
+ torch.musa.empty_cache()
+ elif is_torch_npu_available():
+ torch.npu.empty_cache()
+ elif is_torch_mps_available(min_version="2.0"):
+ torch.mps.empty_cache()
+ else:
+ torch.cuda.empty_cache()
+
+ kwargs = {}
+
+    # For LOMO optimizers you need to explicitly use the learning rate
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ kwargs["learning_rate"] = self._get_learning_rate()
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
+
+ if self.use_apex:
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ if num_items_in_batch is not None:
+ if self.compute_loss_func or self.model_accepts_loss_kwargs:
+ loss *= self.args.gradient_accumulation_steps
+ # Average tokens across devices is orthogonal to gradient accumulation
+ loss *= self.args.world_size
+ self.accelerator.backward(loss, **kwargs)
+
+ return loss.detach() / self.args.gradient_accumulation_steps
+
+
+def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ """
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
+ Subclass and override for custom behavior.
+ """
+ if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
+ labels = inputs.pop("labels")
+ else:
+ labels = None
+ if num_items_in_batch is not None:
+ num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
+ num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
+ if self.model_accepts_loss_kwargs:
+ loss_kwargs = {}
+ if num_items_in_batch is not None:
+ loss_kwargs["num_items_in_batch"] = num_items_in_batch
+ inputs = {**inputs, **loss_kwargs}
+ outputs = model(**inputs)
+ # Save past state if it exists
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index]
+
+ if labels is not None:
+ unwrapped_model = self.accelerator.unwrap_model(model)
+ if _is_peft_model(unwrapped_model):
+ model_name = unwrapped_model.base_model.model._get_name()
+ else:
+ model_name = unwrapped_model._get_name()
+ # User-defined compute_loss function
+ if self.compute_loss_func is not None:
+ loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
+ elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ loss = self.label_smoother(outputs, labels, shift_labels=True)
+ else:
+ loss = self.label_smoother(outputs, labels)
+ else:
+ if isinstance(outputs, dict) and "loss" not in outputs:
+ raise ValueError(
+ "The model did not return a loss from the inputs, only the following keys: "
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+ )
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+ return (loss, outputs) if return_outputs else loss
diff --git a/llava/train/utils.py b/llava/train/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c2db7328c3c27e8cdcb21ce207571285cab121
--- /dev/null
+++ b/llava/train/utils.py
@@ -0,0 +1,270 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+import json
+import os
+import pathlib
+import re
+import warnings
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+from accelerate.hooks import add_hook_to_module
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+
+from llava.train.sequence_parallel.globals import get_pg_manager, get_ulysses_sp_pg
+
+
+def rprint(*args, **kwargs):
+ rank = int(os.environ.get("RANK", 0))
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ if world_size > 1 and dist.is_initialized():
+ return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
+ else:
+ return print(*args, **kwargs)
+
+
+def mprint(*args, **kwargs):
+ rank = int(os.environ.get("RANK", 0))
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ if world_size > 1 and dist.is_initialized():
+ if rank == 0:
+ return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
+ else:
+ return
+ else:
+ return print(*args, **kwargs)
+
+
+def is_local(model_name_or_path: str) -> bool:
+ return os.path.isdir(model_name_or_path)
+
+
+def get_checkpoint_path(output_dir: str, checkpoint_prefix: str = "checkpoint") -> tuple[str | None, bool]:
+ output_dir = os.path.abspath(output_dir)
+ pathlib_dir = pathlib.Path(output_dir)
+
+ if list(pathlib_dir.glob("config.json")):
+ # training has been finished
+ return output_dir, False
+ else:
+ try:
+ ordering_and_checkpoint_path = []
+ glob_checkpoints = [
+ str(x) for x in pathlib.Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)
+ ]
+ for path in glob_checkpoints:
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+ if regex_match is not None and regex_match.groups() is not None:
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+ return checkpoints_sorted[-1][1], True
+        except Exception:
+ return None, True
+
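+
+# Hedged sketch of how the (path, continue_training) pair returned above is
+# typically consumed by the training entry points; "./runs/exp0" is a
+# placeholder path, not a real experiment directory.
+def _example_resume_decision(output_dir: str = "./runs/exp0"):
+    resume_path, continue_training = get_checkpoint_path(output_dir)
+    if not continue_training:
+        return "finished"  # a config.json is already present, so training is done
+    # otherwise resume from the newest checkpoint-* folder, or start from scratch
+    return resume_path if resume_path else "from-scratch"
+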
+
+def prepare_config_for_training(
+ config: PretrainedConfig, model_args: dataclass, training_args: dataclass, data_args: dataclass
+) -> None:
+ config.chat_template = model_args.chat_template
+ assert model_args.vision_tower is not None, "requires vision tower"
+ assert model_args.speech_tower is not None, "requires speech tower"
+ assert model_args.sound_tower is not None, "requires sound tower"
+ # set module configurations
+ if getattr(config, "llm_cfg", None) is None:
+ config.llm_cfg = model_args.model_name_or_path
+ if getattr(config, "vision_tower_cfg", None) is None:
+ config.vision_tower_cfg = model_args.vision_tower
+ if getattr(config, "speech_tower_cfg", None) is None:
+ config.speech_tower_cfg = model_args.speech_tower
+ if getattr(config, "sound_tower_cfg", None) is None:
+ config.sound_tower_cfg = model_args.sound_tower
+ if getattr(config, "mm_projector_cfg", None) is None:
+ config.mm_projector_cfg = model_args.mm_projector
+ if getattr(config, "speech_mm_projector_cfg", None) is None:
+ config.speech_mm_projector_cfg = model_args.speech_mm_projector
+ if getattr(config, "sound_mm_projector_cfg", None) is None:
+ config.sound_mm_projector_cfg = model_args.sound_mm_projector
+ # set default dtype
+    config.model_dtype = str(torch.bfloat16 if training_args.bf16 else torch.float16)
+ # set tuning modules
+ config.tune_language_model = training_args.tune_language_model
+ config.tune_vision_tower = training_args.tune_vision_tower
+ config.tune_speech_tower = training_args.tune_speech_tower
+ config.tune_sound_tower = training_args.tune_sound_tower
+ config.tune_mm_projector = training_args.tune_mm_projector
+ config.tune_speech_mm_projector = training_args.tune_speech_mm_projector
+ config.tune_sound_mm_projector = training_args.tune_sound_mm_projector
+ # set data args
+ # Get the image_aspect_ratio from the config if is defined there
+ # (case of resuming from a checkpoint) or from the data_args
+ # (i.e. from the command line when starting a new training).
+ if getattr(data_args, "image_aspect_ratio", None) is not None:
+ if getattr(config, "image_aspect_ratio", None) is None:
+ config.image_aspect_ratio = data_args.image_aspect_ratio
+ elif getattr(config, "image_aspect_ratio", None) is not None:
+ data_args.image_aspect_ratio = config.image_aspect_ratio
+ else:
+ raise ValueError("image_aspect_ratio must be set either in data_args or in the pretrained config")
+
+ if (
+ hasattr(training_args, "deepspeed")
+ and training_args.deepspeed is not None
+ and "mics" in training_args.deepspeed
+ ):
+ config.deepspeed = training_args.deepspeed
+
+ for key, value in model_args.__dict__.items():
+        try:
+            value = json.loads(value)
+        except (TypeError, ValueError):
+            # keep values that are not JSON strings (e.g. plain strings, numbers, None) as-is
+            pass
+ setattr(config, key, value)
+
+
+def vision_resolution_elevation(model: PreTrainedModel, config: PretrainedConfig):
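+    """Resize the vision tower's positional embeddings to `config.vision_resolution` if set; skipped
+    for RADIO-based vision towers."""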
+ vision_tower = model.get_vision_tower()
+ if vision_tower is not None and "radio" not in vision_tower.__class__.__name__.lower():
+ vision_tower._maybe_resize_pos_embeds(
+ model=vision_tower.vision_tower,
+ image_processor=vision_tower.image_processor,
+ resolution=getattr(config, "vision_resolution", -1),
+ interpolate_mode=getattr(config, "interpolate_mode", "linear"),
+ )
+
+
+def unit_test_rope_scaling(model: PreTrainedModel, config: PretrainedConfig, training_args: dataclass):
+ return False
+
+
+def calculate_loss_weight(labels, ignore_index=-100):
+ # (Qinghao): Weighted loss based on num_active_elements
+ # To achieve accurate sequence parallel loss calculation, we need to get
+ # the real active_elements of each sequence partitions.
+ # For data parallelism, the loss almost remains the same (also more accurate).
+ shift_labels = labels[..., 1:].contiguous()
+ shift_labels = shift_labels.view(-1)
+
+ padding_mask = shift_labels.eq(ignore_index) # IGNORE_INDEX = -100 by default
+ num_active_elements = padding_mask.numel() - padding_mask.long().sum()
+
+ # global_active_sum = copy.deepcopy(num_active_elements)
+ global_active_sum = num_active_elements.detach().clone()
+
+ dist.all_reduce(global_active_sum)
+ loss_weight = num_active_elements / global_active_sum * dist.get_world_size()
+ return loss_weight
+
+
+def reshard_hiddne_states_and_labels(hidden_states, labels):
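+    """Gather labels and hidden states across the Ulysses sequence-parallel group, keep only
+    positions whose (shifted) label is not IGNORE_INDEX, and re-shard these effective tokens
+    evenly across SP ranks. Returns the local shard of hidden states and labels."""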
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ sp_degree = PROCESS_GROUP_MANAGER.sp_degree
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
+ sp_group = PROCESS_GROUP_MANAGER.ulysses_pg
+ from llava.constants import IGNORE_INDEX
+
+ # Get the seq len on different sp ranks
+ bs, shard_seqlen = labels.shape
+ ulysses_seq_len = [torch.zeros(1, dtype=torch.int64, device=labels.device) for _ in range(sp_degree)]
+ dist.barrier(group=sp_group)
+ dist.all_gather(ulysses_seq_len, torch.tensor(shard_seqlen, device=labels.device), group=sp_group)
+ dist.barrier(group=sp_group)
+ global_seq_len = torch.cat(ulysses_seq_len, dim=0)
+    # Gather all labels and flatten them
+ all_labels = [
+ torch.zeros(bs, seq_len, dtype=labels.dtype, device=labels.device).contiguous() for seq_len in ulysses_seq_len
+ ]
+ dist.all_gather(all_labels, labels.contiguous(), group=sp_group)
+ # flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].view(-1)
+ flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].contiguous().view(-1)
+ # Get the label!=IGNORE_INDEX's index
+ flatten_label_mask = flatten_global_labels.ne(IGNORE_INDEX)
+ flatten_effective_label_index = flatten_label_mask.nonzero(as_tuple=True)
+ # padding the effective_label_index if the length is smaller than sp_degree
+ if flatten_effective_label_index[0].shape[0] < sp_degree:
+ warnings.warn(
+ f"The effective label length {flatten_effective_label_index[0].shape[0]} is smaller than sp_degree {sp_degree}, padding the index"
+ )
+ repeat_num = sp_degree // flatten_effective_label_index[0].shape[0] + 1
+ else:
+ repeat_num = 1
+ # Reconstruct the labels by selecting from the global labels
+ effective_global_labels = flatten_global_labels[flatten_effective_label_index]
+ if repeat_num > 1:
+ effective_global_labels = effective_global_labels.repeat(repeat_num)
+    # Global effective sequence length
+ global_effective_seq_len = effective_global_labels.shape[0]
+ reshard_size = global_effective_seq_len // sp_degree
+ # Hyper parameters to reshard the hidden states and labels
+ if sp_rank == 0:
+ original_start_id = 0
+ original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
+ start_id = 0
+ end_id = reshard_size * (sp_rank + 1)
+ elif sp_rank == sp_degree - 1:
+ original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
+ original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
+ start_id = reshard_size * sp_rank
+ end_id = global_effective_seq_len
+ else:
+ original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
+ original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
+ start_id = reshard_size * sp_rank
+ end_id = reshard_size * (sp_rank + 1)
+ # Get the local labels
+ effective_local_labels = torch.narrow(effective_global_labels, 0, start_id, end_id - start_id)
+    # Gather all hidden states and flatten them
+ # all_hidden_states = [torch.zeros(bs, seq_len, hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=True).contiguous() for seq_len in ulysses_seq_len]
+ all_hidden_states = torch.zeros(
+ bs, torch.sum(global_seq_len), hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device
+ ).contiguous()
+ all_hidden_states[:, original_start_id:original_end_id, :] += hidden_states
+ dist.barrier(group=sp_group)
+ dist.all_reduce(all_hidden_states, group=sp_group)
+ dist.barrier(group=sp_group)
+ flatten_global_hidden_states = all_hidden_states[:, :-1, :].contiguous().view(-1, hidden_states.shape[-1])
+ # Get the local hidden states
+ effective_flatten_global_hidden_states = flatten_global_hidden_states[flatten_effective_label_index]
+ if repeat_num > 1:
+ effective_flatten_global_hidden_states = effective_flatten_global_hidden_states.repeat(repeat_num, 1)
+ effective_local_hidden_states = torch.narrow(effective_flatten_global_hidden_states, 0, start_id, end_id - start_id)
+
+ return effective_local_hidden_states, effective_local_labels
+
+
+def sp_loss_rescale(shift_labels, loss):
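+    """Weight the local loss by this rank's share of active (non-IGNORE_INDEX) tokens and
+    sum it across the Ulysses sequence-parallel group."""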
+ from llava.constants import IGNORE_INDEX
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ labels_mask = shift_labels.ne(IGNORE_INDEX) # IGNORE_INDEX = -100 by default
+ num_active_elements = torch.sum(labels_mask)
+ global_active_sum = copy.deepcopy(num_active_elements)
+ # dist.barrier(group=get_ulysses_sp_pg())
+ dist.all_reduce(global_active_sum, group=get_ulysses_sp_pg())
+ # print(loss.shape, num_active_elements.shape, global_active_sum.shape)
+ loss = loss * num_active_elements / global_active_sum
+ dist.all_reduce(loss, group=get_ulysses_sp_pg())
+ return loss
diff --git a/llava/utils/__init__.py b/llava/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76ccca7e49181c27dce49da4f49de8968034d63c
--- /dev/null
+++ b/llava/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .utils import *
diff --git a/llava/utils/distributed.py b/llava/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..16aa9663870d0608156d898a72e5b487a23e3c78
--- /dev/null
+++ b/llava/utils/distributed.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+import warnings
+from typing import Any, List, Optional
+
+from torch import distributed as dist
+
+__all__ = [
+ "init",
+ "is_initialized",
+ "size",
+ "rank",
+ "local_size",
+ "local_rank",
+ "is_main",
+ "barrier",
+ "gather",
+ "all_gather",
+]
+
+
+def init() -> None:
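+    """Initialize the default NCCL process group from torchrun-style environment variables; warns and
+    returns without initializing when `RANK` is unset."""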
+ if "RANK" not in os.environ:
+ warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
+ return
+ dist.init_process_group(backend="nccl", init_method="env://")
+
+
+def is_initialized() -> bool:
+ return dist.is_initialized()
+
+
+def size() -> int:
+ return int(os.environ.get("WORLD_SIZE", 1))
+
+
+def rank() -> int:
+ return int(os.environ.get("RANK", 0))
+
+
+def local_size() -> int:
+ return int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+
+
+def local_rank() -> int:
+ return int(os.environ.get("LOCAL_RANK", 0))
+
+
+def is_main() -> bool:
+ return rank() == 0
+
+
+def barrier() -> None:
+ dist.barrier()
+
+
+def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
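+    """Gather `obj` from every rank onto the destination rank; returns the gathered list on the main
+    rank and None elsewhere (or `[obj]` when not running distributed)."""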
+ if not is_initialized():
+ return [obj]
+ if is_main():
+ objs = [None for _ in range(size())]
+ dist.gather_object(obj, objs, dst=dst)
+ return objs
+ else:
+ dist.gather_object(obj, dst=dst)
+ return None
+
+
+def all_gather(obj: Any) -> List[Any]:
+ if not is_initialized():
+ return [obj]
+ objs = [None for _ in range(size())]
+ dist.all_gather_object(objs, obj)
+ return objs
diff --git a/llava/utils/io.py b/llava/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..b72b1a72365d600d060626e8528806910b8297e6
--- /dev/null
+++ b/llava/utils/io.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import json
+import os
+import pickle
+from contextlib import contextmanager
+from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, TextIO, Union
+
+import numpy as np
+import torch
+import yaml
+
+__all__ = [
+ "load",
+ "save",
+ "load_json",
+ "save_json",
+ "load_jsonl",
+ "save_jsonl",
+ "load_mat",
+ "save_mat",
+ "load_npy",
+ "save_npy",
+ "load_npz",
+ "save_npz",
+ "load_pt",
+ "save_pt",
+ "load_yaml",
+ "save_yaml",
+]
+
+
+@contextmanager
+def file_descriptor(f: Union[str, IO], mode: str = "r") -> Iterator[IO]:
+ opened = False
+ try:
+ if isinstance(f, str):
+ f = open(f, mode)
+ opened = True
+ yield f
+ finally:
+ if opened:
+ f.close()
+
+
+def load_json(f: Union[str, TextIO], **kwargs) -> Any:
+ with file_descriptor(f, mode="r") as fd:
+ return json.load(fd, **kwargs)
+
+
+def save_json(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
+ with file_descriptor(f, mode="w") as fd:
+ json.dump(obj, fd, **kwargs)
+
+
+def load_jsonl(f: Union[str, TextIO], **kwargs) -> Any:
+ with file_descriptor(f, mode="r") as fd:
+ return [json.loads(datum, **kwargs) for datum in fd.readlines()]
+
+
+def save_jsonl(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
+ with file_descriptor(f, mode="w") as fd:
+ fd.write("\n".join(json.dumps(datum, **kwargs) for datum in obj))
+
+
+def load_mat(f: Union[str, BinaryIO], **kwargs) -> Any:
+ import scipy.io
+
+ return scipy.io.loadmat(f, **kwargs)
+
+
+def save_mat(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
+ import scipy.io
+
+ scipy.io.savemat(f, obj, **kwargs)
+
+
+def load_npy(f: Union[str, BinaryIO], **kwargs) -> Any:
+ return np.load(f, **kwargs)
+
+
+def save_npy(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
+ np.save(f, obj, **kwargs)
+
+
+def load_npz(f: Union[str, BinaryIO], **kwargs) -> Any:
+ return np.load(f, **kwargs)
+
+
+def save_npz(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
+ np.savez(f, obj, **kwargs)
+
+
+def load_pkl(f: Union[str, BinaryIO], **kwargs) -> Any:
+ with file_descriptor(f, mode="rb") as fd:
+ try:
+ return pickle.load(fd, **kwargs)
+ except UnicodeDecodeError:
+ if "encoding" in kwargs:
+ raise
+ fd.seek(0)
+ return pickle.load(fd, encoding="latin1", **kwargs)
+
+
+def save_pkl(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
+ with file_descriptor(f, mode="wb") as fd:
+ pickle.dump(obj, fd, **kwargs)
+
+
+def load_pt(f: Union[str, BinaryIO], **kwargs) -> Any:
+ return torch.load(f, **kwargs)
+
+
+def save_pt(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
+ torch.save(obj, f, **kwargs)
+
+
+def load_yaml(f: Union[str, TextIO]) -> Any:
+ with file_descriptor(f, mode="r") as fd:
+ return yaml.safe_load(fd)
+
+
+def save_yaml(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
+ with file_descriptor(f, mode="w") as fd:
+ yaml.safe_dump(obj, fd, **kwargs)
+
+
+def load_txt(f: Union[str, TextIO]) -> Any:
+ with file_descriptor(f, mode="r") as fd:
+ return fd.read()
+
+
+def save_txt(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
+ with file_descriptor(f, mode="w") as fd:
+ fd.write(obj)
+
+
+__io_registry: Dict[str, Dict[str, Callable]] = {
+ ".txt": {"load": load_txt, "save": save_txt},
+ ".json": {"load": load_json, "save": save_json},
+ ".jsonl": {"load": load_jsonl, "save": save_jsonl},
+ ".mat": {"load": load_mat, "save": save_mat},
+ ".npy": {"load": load_npy, "save": save_npy},
+ ".npz": {"load": load_npz, "save": save_npz},
+ ".pkl": {"load": load_pkl, "save": save_pkl},
+ ".pt": {"load": load_pt, "save": save_pt},
+ ".pth": {"load": load_pt, "save": save_pt},
+ ".pth.tar": {"load": load_pt, "save": save_pt},
+ ".yaml": {"load": load_yaml, "save": save_yaml},
+ ".yml": {"load": load_yaml, "save": save_yaml},
+}
+
+
+def load(fpath: str, **kwargs) -> Any:
+ assert isinstance(fpath, str), type(fpath)
+
+ for extension in sorted(__io_registry.keys(), key=len, reverse=True):
+ if fpath.endswith(extension) and "load" in __io_registry[extension]:
+ return __io_registry[extension]["load"](fpath, **kwargs)
+
+ raise NotImplementedError(f'"{fpath}" cannot be loaded.')
+
+
+def save(fpath: str, obj: Any, **kwargs) -> None:
+ assert isinstance(fpath, str), type(fpath)
+ os.makedirs(os.path.dirname(fpath), exist_ok=True)
+
+ for extension in sorted(__io_registry.keys(), key=len, reverse=True):
+ if fpath.endswith(extension) and "save" in __io_registry[extension]:
+ __io_registry[extension]["save"](fpath, obj, **kwargs)
+ return
+
+ raise NotImplementedError(f'"{fpath}" cannot be saved.')
diff --git a/llava/utils/logging.py b/llava/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..54487a2f7db0dcb38a5fca831888b34072aaa72f
--- /dev/null
+++ b/llava/utils/logging.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import typing
+
+if typing.TYPE_CHECKING:
+ from loguru import Logger
+else:
+ Logger = None
+
+__all__ = ["logger"]
+
+
+def __get_logger() -> Logger:
+ from loguru import logger
+
+ return logger
+
+
+logger = __get_logger()
diff --git a/llava/utils/media.py b/llava/utils/media.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12e6631462c75ca5b24fcd90d2d840eaae7d998
--- /dev/null
+++ b/llava/utils/media.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import glob
+import os
+import tempfile
+from collections import defaultdict
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Union
+
+import cv2
+import numpy as np
+import PIL
+import PIL.Image
+import requests
+from transformers import PretrainedConfig
+from pydub import AudioSegment
+
+from llava.constants import MEDIA_TOKENS
+from llava.media import Image, Video, Speech, Sound
+from llava.utils import make_list
+from llava.utils.logging import logger
+import torch
+import whisper
+import soundfile as sf
+from librosa import resample as librosa_resample
+from transformers import AutoFeatureExtractor
+import math
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
+import kaldiio
+# wav_processor = AutoFeatureExtractor.from_pretrained('pretrained_models/AF-Whisper')
+wav_processor = AutoFeatureExtractor.from_pretrained('Qwen/Qwen2-Audio-7B')
+
+__all__ = ["extract_media"]
+
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1., a_max=1.)
+ return (x * 32767.).astype(np.int16)
+
+
+def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
+ if isinstance(image, Image):
+ if image.path.startswith("http://") or image.path.startswith("https://"):
+ image = PIL.Image.open(requests.get(image.path, stream=True).raw)
+ else:
+ image = PIL.Image.open(image.path)
+ return image
+
+
+def _load_video_bytesio(video_bytesio: BytesIO, *, num_frames: int) -> List[PIL.Image.Image]:
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
+        temp_video.write(video_bytesio.read())
+        temp_video.flush()  # ensure all bytes are on disk before cv2 opens the file by name
+        temp_video_name = temp_video.name
+ return _load_video(temp_video_name, num_frames=num_frames)
+
+
+def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
+ # Load video frames from a directory
+ if os.path.isdir(video_path):
+ frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
+ indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
+ return [PIL.Image.open(frame_paths[index]) for index in indices]
+
+ # Load video frames from a video file
+ vidcap = cv2.VideoCapture(video_path)
+
+ # Find the last frame as frame count might not be accurate
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+ while frame_count > 0:
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
+ if vidcap.grab():
+ break
+ frame_count -= 1
+ else:
+ raise ValueError(f"Video '{video_path}' has no frames.")
+
+ # Extract frames uniformly
+ indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
+ frames = {}
+ for index in indices:
+ if index in frames:
+ continue
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
+ success, frame = vidcap.read()
+ if not success:
+ logger.warning(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
+ continue
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames[index] = PIL.Image.fromarray(frame)
+ return [frames[index] for index in indices if index in frames]
+
+
+def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
+ num_frames = config.num_video_frames
+ if getattr(config, "fps") != 0:
+ logger.warning("Extracting frames from video with specified FPS is not supported yet. Ignored.")
+ if isinstance(video.path, BytesIO):
+ frames = _load_video_bytesio(video.path, num_frames=num_frames)
+ else:
+ frames = _load_video(video.path, num_frames=num_frames)
+ return frames
+
+def _load_speech(speech_path: str):
+    # Load the speech waveform and convert it to a log-Mel spectrogram
+ if speech_path is None:
+ return None
+ speech_outputs = []
+
+ speech = whisper.load_audio(speech_path)
+ speech = whisper.pad_or_trim(speech)
+ mel = whisper.log_mel_spectrogram(speech)
+ speech_outputs.append(mel.unsqueeze(0))
+ speech_frames = torch.stack(speech_outputs, dim=0)
+ return speech_frames.numpy().tolist()
+
+def _extract_speech(speech: Speech, config: PretrainedConfig):
+ frames = _load_speech(speech.path)
+ return frames
+
+def _get_num_windows(T, sr):
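+    """Given an audio length of T samples at sample rate sr, return how many non-overlapping 30 s
+    windows (capped at 20) are needed and the corresponding padded length in samples."""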
+
+ window_length = int(30.0 * sr)
+ window_overlap = int(0.0 * sr)
+ max_num_window = 20
+ num_windows = 1
+ if T <= window_length:
+ num_windows = 1
+ full_length = window_length
+ elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
+ num_windows = max_num_window
+ full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
+ else:
+ num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
+ full_length = num_windows * window_length - (num_windows - 1) * window_overlap
+
+ return num_windows, full_length
+
+def _load_audio(file_path, target_sr=16000, duration=30.0, start=0.0):
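+    """Load up to `duration` seconds of audio starting at `start` (pydub for mp3, soundfile otherwise),
+    downmix to mono, resample to `target_sr`, and return a 1-D float array scaled to [-1, 1]."""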
+ if file_path.endswith('.mp3'):
+ audio = AudioSegment.from_file(file_path)
+ if len(audio) > (start + duration) * 1000:
+ audio = audio[start * 1000:(start + duration) * 1000]
+
+ if audio.frame_rate != target_sr:
+ audio = audio.set_frame_rate(target_sr)
+
+ if audio.channels > 1:
+ audio = audio.set_channels(1)
+
+ data = np.array(audio.get_array_of_samples())
+ if audio.sample_width == 2:
+ data = data.astype(np.float32) / np.iinfo(np.int16).max
+ elif audio.sample_width == 4:
+ data = data.astype(np.float32) / np.iinfo(np.int32).max
+ else:
+ raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
+
+ else:
+ with sf.SoundFile(file_path) as audio:
+ original_sr = audio.samplerate
+ channels = audio.channels
+
+ max_frames = int((start + duration) * original_sr)
+
+ audio.seek(int(start * original_sr))
+ frames_to_read = min(max_frames, len(audio))
+ data = audio.read(frames_to_read)
+
+ if data.max() > 1 or data.min() < -1:
+ data = data / max(abs(data.max()), abs(data.min()))
+
+ if original_sr != target_sr:
+ if channels == 1:
+ data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
+ else:
+ data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
+ else:
+ if channels != 1:
+ data = data.T[0]
+
+ if data.min() >= 0:
+ data = 2 * data / abs(data.max()) - 1.0
+ else:
+ data = data / max(abs(data.max()), abs(data.min()))
+
+ assert len(data.shape) == 1, data.shape
+ return data
+
+
+def _load_sound_mask(sound_file, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=20, audio_start=0.0):
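+    """Load an audio file, split it into up to `max_num_window` windows of `window_length` seconds,
+    and run the Whisper-style feature extractor on each window. Returns (per-window mel features as
+    nested lists, per-window mel-frame attention masks, per-window output-embedding masks); falls
+    back to zero tensors if the file cannot be loaded."""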
+ if sound_file is None:
+ return None
+ window_length = int(window_length * sample_rate)
+ window_overlap = int(window_overlap * sample_rate)
+ max_num_window = int(max_num_window)
+ duration = max_num_window * (window_length - window_overlap) + window_overlap
+
+ sound_outputs = []
+ audio_feature_masks = []
+ audio_embed_masks = []
+
+ try:
+ audio_data = _load_audio(sound_file, sample_rate, duration, audio_start) # already cuts to max duration
+ T = len(audio_data)
+ audio_data = audio_data.reshape(1, -1)
+ num_windows, full_length = _get_num_windows(T, sample_rate)
+
+ audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
+ for i in range(num_windows):
+ audio_embed_mask = torch.zeros(750)
+ start = i * (window_length - window_overlap)
+ audio_data_tensor_this = audio_data_tensor[:, start:start+window_length]
+ orig_length = audio_data_tensor_this.shape[1]
+            audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt")
+ sound_outputs.append(audio_data_tensor_this["input_features"])
+ # calculate the mask for the input melspec to Whisper
+ melspec_frames_this_window = int(math.ceil(orig_length / 160))
+ feature_attention_mask = torch.zeros(3000, dtype=torch.int32)
+ feature_attention_mask[:melspec_frames_this_window] = 1
+ audio_feature_masks.append(feature_attention_mask.unsqueeze(0))
+ # calculate the mask for the output embedding for use in AF3
+ conv_lengths = (melspec_frames_this_window - 1) // 2 + 1
+ output_embedding_lengths = (conv_lengths - 2) // 2 + 1
+ audio_embed_mask[:output_embedding_lengths] = 1
+ audio_embed_masks.append(audio_embed_mask)
+    except Exception as e:
+        logger.warning(f"Error loading sound file '{sound_file}': {e}. Falling back to zero features.")
+        sound_outputs.append(torch.zeros(1, 128, 3000))
+        audio_feature_masks.append(torch.zeros(1, 3000, dtype=torch.int32))
+        audio_embed_masks.append(torch.zeros(750))
+ sound_outputs = torch.stack(sound_outputs, dim=0)
+ audio_feature_masks = torch.stack(audio_feature_masks, dim=0)
+ audio_embed_masks = torch.stack(audio_embed_masks, dim=0)
+    return sound_outputs.numpy().tolist(), audio_feature_masks, audio_embed_masks
+
+
+def _extract_sound_mask(sound: Sound, config: PretrainedConfig):
+ frames, audio_feature_masks, audio_embed_masks = _load_sound_mask(sound.path)
+ return frames, audio_feature_masks, audio_embed_masks
+
+def extract_media(
+ messages: List[Dict[str, Any]],
+ config: Optional[PretrainedConfig] = None,
+ draft: bool = False,
+) -> Dict[str, List[Any]]:
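+    """Replace media objects in each message with their modality tokens and collect the extracted
+    media (plus feature/embedding masks for sound) into per-modality lists. With `draft=True` the
+    raw media objects are kept instead of being decoded."""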
+ media = defaultdict(list)
+ media_meta = defaultdict(list)
+ for message in messages:
+ text = ""
+ for part in make_list(message["value"]):
+ if isinstance(part, str):
+ for token in MEDIA_TOKENS.values():
+ if token in part:
+ logger.warning(f"Media token '{token}' found in text: '{part}'. Removed.")
+ part = part.replace(token, "").strip()
+ text += part
+ if isinstance(part, (Image, PIL.Image.Image)):
+ if draft:
+ media["image"].append(part)
+ else:
+ media["image"].append(_extract_image(part))
+ text += MEDIA_TOKENS["image"]
+ if isinstance(part, Video):
+ if draft:
+ media["video"].append(part)
+ else:
+ media["video"].append(_extract_video(part, config))
+ text += MEDIA_TOKENS["video"]
+ if isinstance(part, Speech):
+ if draft:
+ media["speech"].append(part)
+ else:
+ media["speech"].append(_extract_speech(part, config))
+ text += MEDIA_TOKENS["speech"]
+ if isinstance(part, Sound):
+ if draft:
+ media["sound"].append(part)
+ else:
+                    sound, audio_feature_masks, audio_embed_masks = _extract_sound_mask(part, config)
+ media["sound"].append(sound)
+ media_meta["sound_feature_masks"].append(audio_feature_masks)
+ media_meta["sound_embed_masks"].append(audio_embed_masks)
+ text += MEDIA_TOKENS["sound"] * len(sound)
+
+ message["value"] = text
+ return media, media_meta
diff --git a/llava/utils/merge_lora_weights_and_save_hf_model.py b/llava/utils/merge_lora_weights_and_save_hf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5296e13430ed9c13ac0285c692faf08631183da0
--- /dev/null
+++ b/llava/utils/merge_lora_weights_and_save_hf_model.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/dvlab-research/LongLoRA
+
+import argparse
+import os
+from typing import Dict
+
+import torch
+import transformers
+from peft import PeftModel
+
+
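+# Example invocation (illustrative):
+#   python llava/utils/merge_lora_weights_and_save_hf_model.py \
+#       --base_model <base_hf_model> --peft_model <lora_adapter_dir> --save_path <output_dir>
+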
+def parse_config():
+ parser = argparse.ArgumentParser(description="arg parser")
+ parser.add_argument("--base_model", type=str, default="/data/pretrained-models/llama-7b-hf")
+ parser.add_argument("--peft_model", type=str, default=None, help="")
+ parser.add_argument("--save_path", type=str, default=None, help="")
+ parser.add_argument("--cache_dir", type=str, default=None, help="./cache_dir")
+ parser.add_argument("--rope_theta", type=int, default=15300000, help="")
+ parser.add_argument("--max_position_embeddings", type=int, default=65536, help="")
+ args = parser.parse_args()
+ return args
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def main():
+ args = parse_config()
+ device = "cuda:0"
+ torch.cuda.set_device(device)
+
+ print("base model", args.base_model)
+ print("peft model", args.peft_model)
+
+ config = transformers.AutoConfig.from_pretrained(
+ args.base_model,
+ cache_dir=args.cache_dir,
+ )
+
+ config.rope_theta = args.rope_theta
+ config.max_position_embeddings = args.max_position_embeddings
+ config.model_max_length = args.max_position_embeddings
+ config.tokenizer_model_max_length = args.max_position_embeddings
+
+ # Load model and tokenizer
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ args.base_model,
+ config=config,
+ cache_dir=args.cache_dir,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ )
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ args.base_model,
+ )
+
+ model = PeftModel.from_pretrained(
+ model,
+ args.peft_model,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ )
+ model = model.merge_and_unload()
+ model.save_pretrained(args.save_path)
+ tokenizer.save_pretrained(args.save_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llava/utils/tokenizer.py b/llava/utils/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c18296eca6607b7a588a61b8386a3c6707d81940
--- /dev/null
+++ b/llava/utils/tokenizer.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional, Sequence
+
+import torch
+import transformers
+
+from llava import conversation as conversation_lib
+from llava.constants import IGNORE_INDEX, SENTINEL_TOKEN
+from llava.mm_utils import tokenizer_image_token
+from llava.utils.logging import logger
+
+__all__ = [
+ "tokenize_conversation",
+ "preprocess_conversation",
+ "infer_stop_tokens",
+]
+
+DUMMY_CONVERSATION = [
+ {"from": "human", "value": "question"},
+ {"from": "gpt", "value": "answer"},
+] * 10
+
+
+def tokenize_conversation_legacy(
+ messages: Sequence[Dict[str, str]],
+ tokenizer: transformers.PreTrainedTokenizer,
+ add_generation_prompt: bool = False,
+ overrides: Optional[Dict[str, str]] = None,
+ no_system_prompt: bool = False,
+) -> torch.Tensor:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ if no_system_prompt:
+ conv.system = ""
+
+ # Skip the first message if it is not from human
+ if messages[0]["from"] != "human":
+ messages = messages[1:]
+
+ # Add a generation prompt if needed
+ if add_generation_prompt:
+ messages.append({"from": "gpt", "value": None})
+
+ conv.messages = []
+ for turn, message in enumerate(messages):
+ role = roles[message["from"]]
+ assert role == conv.roles[turn % 2]
+ if overrides is not None and message["from"] in overrides:
+ conv.append_message(role, overrides[message["from"]])
+ else:
+ conv.append_message(role, message["value"])
+
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
+
+
+def tokenize_conversation(
+ messages: Sequence[Dict[str, str]],
+ tokenizer: transformers.PreTrainedTokenizer,
+ add_generation_prompt: bool = False,
+ overrides: Optional[Dict[str, str]] = None,
+ no_system_prompt: bool = False,
+) -> torch.Tensor:
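+    """Map "human"/"gpt" turns to chat-template roles and tokenize the conversation with the
+    tokenizer's chat template (falling back to the legacy conversation template when the separator
+    style is not AUTO), handling media tokens via `tokenizer_image_token`."""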
+ # Normalize the conversation before tokenization
+ for message in messages:
+ message["value"] = message["value"].strip()
+
+ if conversation_lib.default_conversation.sep_style != conversation_lib.SeparatorStyle.AUTO:
+ return tokenize_conversation_legacy(
+ messages,
+ tokenizer,
+ add_generation_prompt=add_generation_prompt,
+ overrides=overrides,
+ no_system_prompt=no_system_prompt,
+ )
+
+ conversation = []
+ for m in messages:
+ message = {}
+ if m["from"] == "human":
+ message["role"] = "user"
+ elif m["from"] == "gpt":
+ message["role"] = "assistant"
+ else:
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
+
+ message["content"] = m["value"]
+ if overrides is not None and m["from"] in overrides:
+ message["content"] = overrides[m["from"]]
+ conversation.append(message)
+
+ if no_system_prompt:
+ conversation = [{"role": "system", "content": ""}] + conversation
+
+ text = tokenizer.apply_chat_template(
+ conversation,
+ add_generation_prompt=add_generation_prompt,
+ tokenize=False,
+ )
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt")
+
+
+def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
+ if not hasattr(tokenizer, "sentinel_token"):
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
+ tokenizer.sentinel_token = SENTINEL_TOKEN
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
+
+
+def preprocess_conversation(
+ conversation: Sequence[Dict[str, str]],
+ tokenizer: transformers.PreTrainedTokenizer,
+ no_system_prompt: bool = False,
+ retried: bool = False,
+) -> Dict[str, Any]:
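+    """Tokenize the conversation and build training labels: every token that matches the prompt
+    template (obtained by substituting a sentinel for assistant turns) is masked with IGNORE_INDEX,
+    so only assistant responses are supervised."""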
+ inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
+ labels = torch.ones_like(inputs) * IGNORE_INDEX
+
+ # Generate the template by replacing the assistant's response with a sentinel.
+ _maybe_add_sentinel_token(tokenizer)
+ template = tokenize_conversation(
+ conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
+ )
+
+ # Remove sentinel tokens from the template.
+ mask = torch.ones_like(template, dtype=torch.bool)
+ for k in range(template.size(0) - 1):
+ if template[k] == tokenizer.sentinel_token_id:
+ mask[k : k + 2] = False
+ # NOTE(zhijianl): This is to handle the corner case where there is an empty token before the sentinel token.
+ if k > 0 and retried:
+ mask[k - 1] = False
+ template = template[mask]
+
+ # Match the tokenized conversation with the template (with no assistant's response).
+ # Every token that is not matched will be included in the label for training.
+ p = 0
+ for k in range(inputs.size(0)):
+ if p < template.size(0) and inputs[k] == template[p]:
+ p += 1
+ else:
+ labels[k] = inputs[k]
+
+ # Mask all tokens in the label if the template is not fully matched.
+ if p < template.size(0):
+ if not retried:
+ return preprocess_conversation(
+ conversation,
+ tokenizer,
+ no_system_prompt=no_system_prompt,
+ retried=True,
+ )
+ logger.error(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
+ labels[:] = IGNORE_INDEX
+
+ return {"input_ids": inputs, "labels": labels}
+
+
+def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
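+    """Tokenize a dummy conversation with a sentinel in place of assistant replies and collect the
+    token that follows each sentinel (plus EOS) as stop tokens."""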
+ _maybe_add_sentinel_token(tokenizer)
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
+
+ stop_tokens = {tokenizer.eos_token}
+ for k in range(template.size(0) - 1):
+ if template[k] == tokenizer.sentinel_token_id:
+ stop_token = tokenizer.decode(template[k + 1])
+ stop_tokens.add(stop_token)
+ return list(stop_tokens)
diff --git a/llava/utils/utils.py b/llava/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bd28e5c2763a3f308751ee92a7d685f4d144ba8
--- /dev/null
+++ b/llava/utils/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, List
+
+__all__ = ["make_list", "disable_torch_init"]
+
+
+def make_list(obj: Any) -> List:
+ return obj if isinstance(obj, list) else [obj]
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2d67791e4acddd0cc6432d5532809666e4e4204c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,34 @@
+# Core dependencies
+uv
+hydra-core
+loguru
+Pillow
+pydub
+torch
+torchvision
+
+
+# Transformers and training utilities
+transformers==4.46.0
+pytorchvideo==0.1.5
+deepspeed==0.15.4
+accelerate==0.34.2
+numpy==1.26.4
+opencv-python-headless==4.8.0.76
+matplotlib
+
+# Audio
+soundfile
+librosa
+openai-whisper
+ftfy
+ffmpeg
+jiwer
+einops
+wandb
+kaldiio
+peft==0.14.0
+
+# Compatibility fix
+protobuf==3.20.*
+triton==3.1.0
\ No newline at end of file
diff --git a/static/af3_main_diagram-1.png b/static/af3_main_diagram-1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0cf0f8dc986b77cd4b220514cba07d0e85307ac3
--- /dev/null
+++ b/static/af3_main_diagram-1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3908e928f7df50f1860d05d4e31343c44d46550c46bfaff4bb2cf93cea48fd14
+size 229099
diff --git a/static/af3_radial-1.png b/static/af3_radial-1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7bc3da1afaee9d2363df33b66feae1228023a1c
--- /dev/null
+++ b/static/af3_radial-1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c4f1ede2dd47b45c40996a2827cb34312662f81d92c85d71eec596f9f06631e
+size 144906
diff --git a/static/af3_sota.png b/static/af3_sota.png
new file mode 100644
index 0000000000000000000000000000000000000000..29d68ae403c1cf116cec34fb1f6405110d14e778
--- /dev/null
+++ b/static/af3_sota.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1dce527eb4625a47f2819388a0d9d05978a3d9287b839c85a41058e3b0c1208f
+size 266220
diff --git a/static/audio/audio2.wav b/static/audio/audio2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6a82a86a93ad0feee4d795843d081a075a51818f
--- /dev/null
+++ b/static/audio/audio2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1800465f4c47fb96e5852d760d8482c96556e9ab5382909ab4cedde4a81fb6dd
+size 12068618
diff --git a/static/chat/audio1.mp3 b/static/chat/audio1.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..d03ef309274b8524454bb1c1ca835b6029043d2d
--- /dev/null
+++ b/static/chat/audio1.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:927a50999b06065d1e2d728710e54e4205efce7714970f53c1f4e355923034a0
+size 480621
diff --git a/static/chat/audio2.mp3 b/static/chat/audio2.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..d7f8fa55da072304a9745f8adc956245864e3d37
--- /dev/null
+++ b/static/chat/audio2.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a58b3ae4e71f0cfb54fbdcb3bb22c710b61dcb6865d3289c4a86054d0648653
+size 481005
diff --git a/static/emergent/audio1.wav b/static/emergent/audio1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..63f3ff1fe21de7cc319948818fe194f41269728d
--- /dev/null
+++ b/static/emergent/audio1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9207c04cd5d40eeef532a3ced07c5986f6afd4e9a1d908a0d5fd043e120c7375
+size 880684
diff --git a/static/logo-no-bg.png b/static/logo-no-bg.png
new file mode 100644
index 0000000000000000000000000000000000000000..d18ebe4ce6f29f446fdded83c88502186cdff347
--- /dev/null
+++ b/static/logo-no-bg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4868114d109a9939114016c32691405f6d8dfdcdccb29dfce63e52f0b6e9540
+size 464155
diff --git a/static/speech/.DS_Store b/static/speech/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/static/speech/.DS_Store differ
diff --git a/static/speech/339a1acd-afcb-466b-a7b1-8661e59b1e56.wav b/static/speech/339a1acd-afcb-466b-a7b1-8661e59b1e56.wav
new file mode 100644
index 0000000000000000000000000000000000000000..343b27bdb93af24e00c6145815dadb0cb094d8cd
--- /dev/null
+++ b/static/speech/339a1acd-afcb-466b-a7b1-8661e59b1e56.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac4d1c8df03a0dc524158cf9336cb50eb26d4d0467f399bc0693eaace1b9cb5b
+size 5574700
diff --git a/static/speech/audio3.wav b/static/speech/audio3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b7b355a244d1371cf45a996aadccb956512f8669
--- /dev/null
+++ b/static/speech/audio3.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c04e07a8fc6013868732ef4a8142c0e131e8dd9b90e451f3bb848e70de42d33
+size 439340
diff --git a/static/speech/bcc6057d-0dda-435d-b956-a96ab27bc9e4.wav b/static/speech/bcc6057d-0dda-435d-b956-a96ab27bc9e4.wav
new file mode 100644
index 0000000000000000000000000000000000000000..4534a0ff902ad5325cc20078fb2e154e3fc9ef10
--- /dev/null
+++ b/static/speech/bcc6057d-0dda-435d-b956-a96ab27bc9e4.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de944c7b2c4289767ae5ef5edfcb1e2136c0e88f72c0ed68ac50b3274f839f3b
+size 5955628
diff --git a/static/speech/be84d293-5e9c-4158-9a1e-b4dd1acb7d70.wav b/static/speech/be84d293-5e9c-4158-9a1e-b4dd1acb7d70.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a0e0093f43bee1c5a2591f8f3a812556db49007f
--- /dev/null
+++ b/static/speech/be84d293-5e9c-4158-9a1e-b4dd1acb7d70.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:293253bdb0ab2023ac9939975d01a36703929fee983aa75bc26eb476070eac8a
+size 17192060
diff --git a/static/speech/fec3402e-7883-45c0-90d4-38647f615dc3.wav b/static/speech/fec3402e-7883-45c0-90d4-38647f615dc3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..52e888aee47db7251a881b2e10601a2032aa1983
--- /dev/null
+++ b/static/speech/fec3402e-7883-45c0-90d4-38647f615dc3.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6048682cfadbbaf2041b03a30d2059f47ba1b3e2f36d566b7ead08634ee9254
+size 134446
diff --git a/static/speech/speaker1.flac b/static/speech/speaker1.flac
new file mode 100644
index 0000000000000000000000000000000000000000..d67b3e6a4749c27291f5b0bcc5492789c84a8261
--- /dev/null
+++ b/static/speech/speaker1.flac
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfcd32cc9eac068f3c796a3f7c421b7a11d845b0c2e318563b88bc9f2ff34790
+size 446286
diff --git a/static/speech/videoplayback.wav b/static/speech/videoplayback.wav
new file mode 100644
index 0000000000000000000000000000000000000000..14c70812a75ad89aca760ff4514cc77972adb9cd
--- /dev/null
+++ b/static/speech/videoplayback.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a375126827221be8e01e6e8d287b0850d81d2b00c3093077be5de090fedfbb8
+size 32190542
diff --git a/static/think/audio1.wav b/static/think/audio1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..af62190e46f5a37525274cadc21828c987a0ed5e
--- /dev/null
+++ b/static/think/audio1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae56b7534f55f16104bc611595e771993f9d2d2db92b52534916ecdeecfda779
+size 4608078
diff --git a/static/think/audio2.wav b/static/think/audio2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f571796955bca4ab46495a5a51bbef70252b1b5d
--- /dev/null
+++ b/static/think/audio2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f151cfdb88d1b8ea37dd2399cddb42eb9dc4e67698109574c32653e02128221
+size 1920078
diff --git a/static/voice/voice_0.mp3 b/static/voice/voice_0.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..064f934b98003ba9afeea249d63945bde1109e34
Binary files /dev/null and b/static/voice/voice_0.mp3 differ
diff --git a/static/voice/voice_1.mp3 b/static/voice/voice_1.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..b2d684841ee81de16644af0aa9f1783fc7c98f7f
Binary files /dev/null and b/static/voice/voice_1.mp3 differ
diff --git a/static/voice/voice_2.mp3 b/static/voice/voice_2.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..005f644b30784b26572d9f874226d7f03841577b
--- /dev/null
+++ b/static/voice/voice_2.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59ed5f5d55bffd6728d26936ec733b8a01c0e5fcbab5f0c859fd37e56198e069
+size 203040