diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..513e66581496adc0ee5d90eddc42cabe83436e2b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,57 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o filter=lfs diff=lfs merge=lfs -text +static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text +static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text +static/af3_sota.png filter=lfs diff=lfs merge=lfs -text +static/audio/audio2.wav filter=lfs diff=lfs merge=lfs -text +static/chat/audio1.mp3 filter=lfs diff=lfs merge=lfs -text +static/chat/audio2.mp3 filter=lfs diff=lfs merge=lfs -text +static/emergent/audio1.wav filter=lfs diff=lfs merge=lfs -text +static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text +static/speech/339a1acd-afcb-466b-a7b1-8661e59b1e56.wav filter=lfs diff=lfs merge=lfs -text +static/speech/audio3.wav filter=lfs diff=lfs merge=lfs -text +static/speech/bcc6057d-0dda-435d-b956-a96ab27bc9e4.wav filter=lfs diff=lfs merge=lfs -text +static/speech/be84d293-5e9c-4158-9a1e-b4dd1acb7d70.wav filter=lfs diff=lfs merge=lfs -text +static/speech/fec3402e-7883-45c0-90d4-38647f615dc3.wav filter=lfs diff=lfs merge=lfs -text +static/think/audio1.wav filter=lfs diff=lfs merge=lfs -text +static/think/audio2.wav filter=lfs diff=lfs merge=lfs -text +static/voice/voice_2.mp3 filter=lfs diff=lfs merge=lfs -text +static/speech/speaker1.flac filter=lfs diff=lfs merge=lfs -text +static/speech/videoplayback.wav filter=lfs diff=lfs merge=lfs -text diff --git a/README.md 
b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5007c25dfd9c5976effddde197cfe2089ef8d035 --- /dev/null +++ b/README.md @@ -0,0 +1,147 @@ +--- +license: other +title: Audio Flamingo 3 Demo +sdk: gradio +emoji: 🚀 +colorFrom: green +colorTo: green +pinned: true +short_description: Audio Flamingo 3 Demo +--- + +
+ + Audio Flamingo 3 🔥🚀🔥 + 
+
+

+ Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models +

+
+ +
+ + + + +
+ +
+ + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +## Overview + +This repo contains the PyTorch implementation of [Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models](). Audio Flamingo 3 (AF3) is a fully open, state-of-the-art Large Audio-Language Model (LALM) that advances reasoning and understanding across speech, sounds, and music. AF3 builds on previous work with innovations in: + +- Unified audio representation learning (speech, sound, music) +- Flexible, on-demand chain-of-thought reasoning (Thinking in Audio) +- Long-context audio comprehension (including speech) of up to 10 minutes +- Multi-turn, multi-audio conversational dialogue (AF3-Chat) +- Voice-to-voice interaction (AF3-Chat) + +Extensive evaluations confirm AF3’s effectiveness, setting new benchmarks on over 20 public audio understanding and reasoning tasks. + + +## Main Results + +Audio Flamingo 3 outperforms prior SOTA models including GAMA, Audio Flamingo, Audio Flamingo 2, Qwen-Audio, Qwen2-Audio, Qwen2.5-Omni, LTU, LTU-AS, SALMONN, AudioGPT, Gemini Flash v2, and Gemini Pro v1.5 on a number of understanding and reasoning benchmarks. + +
+ +
+ +
+ +
+ +## Audio Flamingo 3 Architecture + +Audio Flamingo 3 uses the AF-Whisper unified audio encoder, an MLP-based audio adaptor, a decoder-only LLM backbone (Qwen2.5-7B), and a streaming TTS module (AF3-Chat). +Audio Flamingo 3 can take up to 10 minutes of audio input. + +
+ +
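+For reference, here is a minimal Python sketch of a single inference call through this stack, mirroring the demo in `app.py` and the CLI examples below; the checkpoint path and audio file are illustrative:
+
+```python
+import llava
+
+# Load the pretrained AF3 checkpoint (e.g., downloaded from HuggingFace).
+model = llava.load("/path/to/checkpoint/af3-7b")
+
+# Optional: attach the stage-3.5 PEFT adapter for "thinking" mode,
+# the Python equivalent of the --peft-mode CLI flag below.
+# import torch
+# from peft import PeftModel
+# model = PeftModel.from_pretrained(
+#     model,
+#     "/path/to/checkpoint/af3-7b/stage35",
+#     device_map="auto",
+#     torch_dtype=torch.float16,
+# )
+
+# Wrap the audio file and pair it with a text prompt, then generate a response.
+sound = llava.Sound("static/audio/audio2.wav")
+response = model.generate_content(
+    [sound, "Please describe the audio in detail."],
+    generation_config=model.default_generation_config,
+)
+print(response)
+```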
+ +## Installation + +```bash +./environment_setup.sh af3 +``` + +## Code Structure + +- The folder ```audio_flamingo_3/``` contains the main training and inference code of Audio Flamingo 3. +- The folder ```audio_flamingo_3/scripts``` contains the inference scripts of Audio Flamingo 3 in case you would like to use our pretrained checkpoints on HuggingFace. + +Each folder is self-contained and we expect no cross dependencies between these folders. This repo does not contain the code for the Streaming-TTS pipeline, which will be released in the near future. + +## Single Line Inference + +To run inference with the stage 3 model directly, run the command below: +```bash +python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav +``` + +To run inference with the stage 3.5 model, run the command below: +```bash +python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --model-path /path/to/checkpoint/af3-7b/stage35 --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav --peft-mode +``` + +## References + +The main training and inference code within each folder is modified from [NVILA](https://github.com/NVlabs/VILA/tree/main) ([Apache license](incl_licenses/License_1.md)). + +## License + +- The code in this repo is under the [MIT license](incl_licenses/MIT_license.md). +- The checkpoints are for non-commercial use only ([NVIDIA OneWay Noncommercial License](incl_licenses/NVIDIA_OneWay_Noncommercial_License.docx)). They are also subject to the [Qwen Research license](https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/LICENSE), the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the original licenses accompanying each training dataset. +- Notice: Audio Flamingo 3 is built with Qwen-2.5. Qwen is licensed under the Qwen RESEARCH LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved. 
+ + +## Citation + +- Audio Flamingo 2 +``` +@article{ghosh2025audio, + title={Audio Flamingo 2: An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities}, + author={Ghosh, Sreyan and Kong, Zhifeng and Kumar, Sonal and Sakshi, S and Kim, Jaehyeon and Ping, Wei and Valle, Rafael and Manocha, Dinesh and Catanzaro, Bryan}, + journal={arXiv preprint arXiv:2503.03983}, + year={2025} +} +``` + +- Audio Flamingo +``` +@inproceedings{kong2024audio, + title={Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities}, + author={Kong, Zhifeng and Goel, Arushi and Badlani, Rohan and Ping, Wei and Valle, Rafael and Catanzaro, Bryan}, + booktitle={International Conference on Machine Learning}, + pages={25125--25148}, + year={2024}, + organization={PMLR} +} +``` \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed35e40c41d5cbcc5e0085abb7623f9dc409f7d --- /dev/null +++ b/app.py @@ -0,0 +1,306 @@ +import gradio as gr +import torch +import llava +from peft import PeftModel +import os +from huggingface_hub import snapshot_download +import copy + +# --------------------------------- +# SINGLE-TURN MODEL SETUP +# --------------------------------- + +MODEL_BASE_SINGLE = snapshot_download(repo_id="nvidia/audio-flamingo-3") +MODEL_BASE_THINK = os.path.join(MODEL_BASE_SINGLE, 'stage35') + +model_single = llava.load(MODEL_BASE_SINGLE, model_base=None) +model_single_copy = copy.deepcopy(model_single) + +# Move the model to GPU +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model_single = model_single.to_empty(device=device) + +generation_config_single = model_single.default_generation_config + +model_think = PeftModel.from_pretrained( + model_single, + MODEL_BASE_THINK, + device_map="auto", + torch_dtype=torch.float16, +) +model_think.to(device) + +# # --------------------------------- +# # MULTI-TURN MODEL SETUP +# # --------------------------------- +# MODEL_BASE_MULTI = snapshot_download(repo_id="nvidia/audio-flamingo-3-chat") +# model_multi = llava.load(MODEL_BASE_MULTI, model_base=None, devices=[0]) +# generation_config_multi = model_multi.default_generation_config + + +# --------------------------------- +# SINGLE-TURN INFERENCE FUNCTION +# --------------------------------- +def single_turn_infer(audio_file, prompt_text): + try: + sound = llava.Sound(audio_file) + full_prompt = f"\n{prompt_text}" + response = model_single_copy.generate_content([sound, full_prompt], generation_config=generation_config_single) + return response + except Exception as e: + return f"โŒ Error: {str(e)}" + +# --------------------------------- +# MULTI-TURN INFERENCE FUNCTION +# --------------------------------- +# def multi_turn_chat(user_input, audio_file, history, current_audio): +# try: +# if audio_file is not None: +# current_audio = audio_file # Update state if a new file is uploaded + +# if current_audio is None: +# return history + [("System", "โŒ Please upload an audio file before chatting.")], history, current_audio + +# sound = llava.Sound(current_audio) +# prompt = f"\n{user_input}" + +# response = model_multi.generate_content([sound, prompt], generation_config=generation_config_multi) + +# history.append((user_input, response)) +# return history, history, current_audio +# except Exception as e: +# history.append((user_input, f"โŒ Error: {str(e)}")) +# return history, history, current_audio + +def think_infer(audio_file, prompt_text): + try: + sound = 
llava.Sound(audio_file) + full_prompt = f"\n{prompt_text}" + response = model_think.generate_content([sound, full_prompt], generation_config=generation_config_single) + return response + except Exception as e: + return f"โŒ Error: {str(e)}" + +# --------------------------------- +# MULTI-TURN INFERENCE FUNCTION +# --------------------------------- +# def multi_turn_chat(user_input, audio_file, history, current_audio): +# try: +# if audio_file is not None: +# current_audio = audio_file # Update state if a new file is uploaded + +# if current_audio is None: +# return history + [("System", "โŒ Please upload an audio file before chatting.")], history, current_audio + +# sound = llava.Sound(current_audio) +# prompt = f"\n{user_input}" + +# response = model_multi.generate_content([sound, prompt], generation_config=generation_config_multi) + +# history.append((user_input, response)) +# return history, history, current_audio +# except Exception as e: +# history.append((user_input, f"โŒ Error: {str(e)}")) +# return history, history, current_audio +# --------------------------------- +# INTERFACE +# --------------------------------- +with gr.Blocks(css=""" +.gradio-container { + max-width: 100% !important; + width: 100% !important; + margin: 0 !important; + padding: 0 !important; +} +#component-0, .gr-block.gr-box { + width: 100% !important; +} +.gr-block.gr-box, .gr-column, .gr-row { + padding: 0 !important; + margin: 0 !important; +} +""") as demo: + + with gr.Column(): + gr.HTML(""" +
+ Audio Flamingo 3 Logo +

Audio Flamingo 3

+

Advancing Audio Intelligence with Fully Open Large Audio-Language Models

+
+ +
+ + arXiv + + + Demo Page + + + GitHub + + + GitHub Stars + +
+
+ + + + + + +
+
+ + + + + + + + + + + + +
+""") + # gr.Markdown("#### NVIDIA (2025)") + + with gr.Tabs(): + # ---------------- SINGLE-TURN ---------------- + with gr.Tab("๐ŸŽฏ Single-Turn Inference"): + with gr.Row(): + with gr.Column(): + audio_input_single = gr.Audio(type="filepath", label="Upload Audio") + prompt_input_single = gr.Textbox(label="Prompt", placeholder="Ask a question about the audio...", lines=8) + btn_single = gr.Button("Generate Answer") + + gr.Examples( + examples=[ + ["static/emergent/audio1.wav", "What is surprising about the relationship between the barking and the music?"], + ["static/audio/audio2.wav", "Please describe the audio in detail."], + ["static/speech/audio3.wav", "Transcribe any speech you hear."], + ], + inputs=[audio_input_single, prompt_input_single], + label="๐Ÿงช Try Examples" + ) + + with gr.Column(): + output_single = gr.Textbox(label="Model Response", lines=15) + + btn_single.click(fn=single_turn_infer, inputs=[audio_input_single, prompt_input_single], outputs=output_single) + with gr.Tab("๐Ÿค” Think / Long"): + + with gr.Row(): + with gr.Column(): + audio_input_think = gr.Audio(type="filepath", label="Upload Audio") + prompt_input_think = gr.Textbox(label="Prompt", placeholder="To enable thinking, please add the text: '\nPlease think and reason about the input music before you respond.' to your prompt.", lines=8) + btn_think = gr.Button("Generate Answer") + + gr.Examples( + examples=[ + ["static/think/audio1.wav", "What are the two people doing in the audio Choose the correct option from the following options:\n(A) One person is demonstrating how to use the equipment\n(B) The two people are discussing how to use the equipment\n(C) The two people are disassembling the equipment\n(D) One person is teaching another person how to use a piece of equipment\n"], + ["static/think/audio2.wav", "Is the boat in the video moving closer or further away? Choose the correct option from the following options:\n(A) Closer\n(B) Further\n"], + ["static/speech/videoplayback.wav", "Generate a detailed caption for the input audio, describing all notable speech, sound, and musical events comprehensively. In the caption, transcribe all spoken content by all speakers in the audio precisely."], + ["static/speech/speaker1.flac", "Transcribe any input speech in the input audio."], + ], + inputs=[audio_input_think, prompt_input_think], + label="๐Ÿงช Try Examples" + ) + + with gr.Column(): + output_think = gr.Textbox(label="Model Response", lines=30) + + btn_think.click(fn=think_infer, inputs=[audio_input_think, prompt_input_think], outputs=output_think) + # ---------------- MULTI-TURN CHAT ---------------- + with gr.Tab("๐Ÿ’ฌ Multi-Turn Chat"): + # chatbot = gr.Chatbot(label="Audio Chatbot") + # audio_input_multi = gr.Audio(type="filepath", label="Upload or Replace Audio Context") + # user_input_multi = gr.Textbox(label="Your message", placeholder="Ask a question about the audio...", lines=8) + # btn_multi = gr.Button("Send") + # history_state = gr.State([]) # Chat history + # current_audio_state = gr.State(None) # Most recent audio file path + + # btn_multi.click( + # fn=multi_turn_chat, + # inputs=[user_input_multi, audio_input_multi, history_state, current_audio_state], + # outputs=[chatbot, history_state, current_audio_state] + # ) + # gr.Examples( + # examples=[ + # ["static/chat/audio1.mp3", "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?"], + # ["static/chat/audio2.mp3", "Switching gears, this one is super energetic and synthetic. 
If I wanted to remix the calming folk piece into something closer to this, what would you suggest?"], + # ], + # inputs=[audio_input_multi, user_input_multi], + # label="๐Ÿงช Try Examples" + # ) + # Add the link to another Gradio demo here + gr.Markdown("๐Ÿ”— [Check out our other Gradio demo here](https://huggingface.co/spaces/nvidia/audio-flamingo-3-chat)") + + with gr.Tab("๐Ÿ—ฃ๏ธ Speech Prompt"): + # gr.Markdown("Use your **voice** to talk to the model.") + + # with gr.Row(): + # with gr.Column(): + # speech_input = gr.Audio(type="filepath", label="Speak or Upload Audio") + # btn_speech = gr.Button("Submit") + # gr.Examples( + # examples=[ + # ["static/voice/voice_0.mp3"], + # ["static/voice/voice_1.mp3"], + # ["static/voice/voice_2.mp3"], + # ], + # inputs=speech_input, + # label="๐Ÿงช Try Examples" + # ) + # with gr.Column(): + # response_box = gr.Textbox(label="Model Response", lines=15) + + # btn_speech.click(fn=speech_prompt_infer, inputs=speech_input, outputs=response_box) + # Add the link to another Gradio demo here + gr.Markdown("๐Ÿ”— [Check out our other Gradio demo here](https://huggingface.co/spaces/nvidia/audio-flamingo-3-chat)") + + # ---------------- ABOUT ---------------- + with gr.Tab("๐Ÿ“„ About"): + gr.Markdown(""" +### ๐Ÿ“š Overview + +**Audio Flamingo 3** is a fully open state-of-the-art (SOTA) large audio-language model that advances reasoning and understanding across speech, sound, and music. AF3 introduces: + +(i) AF-Whisper, a unified audio encoder trained using a novel strategy for joint representation learning across all 3 modalities of speech, sound, and music; + +(ii) flexible, on-demand thinking, allowing the model to do chain-of-thought reasoning before answering; + +(iii) multi-turn, multi-audio chat; + +(iv) long audio understanding and reasoning (including speech) up to 10 minutes; and + +(v) voice-to-voice interaction. + +To enable these capabilities, we propose several large-scale training datasets curated using novel strategies, including AudioSkills-XL, LongAudio-XL, AF-Think, and AF-Chat, and train AF3 with a novel five-stage curriculum-based training strategy. Trained on only open-source audio data, AF3 achieves new SOTA results on over 20+ (long) audio understanding and reasoning benchmarks, surpassing both open-weight and closed-source models trained on much larger datasets. + +**Key Features:** + +๐Ÿ’ก Audio Flamingo 3 has strong audio, music and speech understanding capabilities. + +๐Ÿ’ก Audio Flamingo 3 supports on-demand thinking for chain-of-though reasoning. + +๐Ÿ’ก Audio Flamingo 3 supports long audio and speech understanding for audios up to 10 minutes. + +๐Ÿ’ก Audio Flamingo 3 can have multi-turn, multi-audio chat with users under complex context. + +๐Ÿ’ก Audio Flamingo 3 has voice-to-voice conversation abilities. + + +""") + + gr.Markdown("ยฉ 2025 NVIDIA | Built with โค๏ธ using Gradio + PyTorch") + + +# ----------------------- +# Launch App +# ----------------------- +if __name__ == "__main__": + demo.launch(share=True) diff --git a/llava/__init__.py b/llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c14ac15313d1fceee30d9a3ef2d69bbba5702722 --- /dev/null +++ b/llava/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from .entry import * +from .media import * diff --git a/llava/cli/infer_audio.py b/llava/cli/infer_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..8041b2ab9aa52fc166c87ce4527ad5a8356857da --- /dev/null +++ b/llava/cli/infer_audio.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import argparse +import importlib.util +import json +import os + +from pydantic import BaseModel +from termcolor import colored + +import llava +from llava import conversation as clib +from llava.media import Image, Video, Sound +from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat +from peft import PeftModel +import torch + +def get_schema_from_python_path(path: str) -> str: + schema_path = os.path.abspath(path) + spec = importlib.util.spec_from_file_location("schema_module", schema_path) + schema_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(schema_module) + + # Get the Main class from the loaded module + Main = schema_module.Main + assert issubclass( + Main, BaseModel + ), f"The provided python file {path} does not contain a class Main that describes a JSON schema" + return Main.schema_json() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-base", "-mb", type=str, required=True) + parser.add_argument("--model-path", "-mp", type=str, required=True) + parser.add_argument("--conv-mode", "-c", type=str, default="auto") + parser.add_argument("--text", type=str) + parser.add_argument("--media", type=str, nargs="+") + parser.add_argument("--json-mode", action="store_true") + parser.add_argument("--peft-mode", action="store_true") + parser.add_argument("--json-schema", type=str, default=None) + args = parser.parse_args() + + # Convert json mode to response format + if not args.json_mode: + response_format = None + elif args.json_schema is None: + response_format = ResponseFormat(type="json_object") + else: + schema_str = get_schema_from_python_path(args.json_schema) + print(schema_str) + response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str)) + + # Load model + model = llava.load(args.model_base) + if args.peft_mode: + model = PeftModel.from_pretrained( + model, + args.model_path, + device_map="auto", + torch_dtype=torch.float16, + ) + # Set conversation mode + clib.default_conversation = clib.conv_templates[args.conv_mode].copy() + + # Prepare multi-modal prompt + prompt = [] + if args.media is not None: + for media in args.media or []: + if any(media.endswith(ext) for ext in [".wav",".mp3", ".flac"]): + media = Sound(media) + else: + raise ValueError(f"Unsupported media type: {media}") + prompt.append(media) + if args.text is not None: + prompt.append(args.text) + + # Generate response + response = model.generate_content(prompt, response_format=response_format) + print(colored(response, "cyan", attrs=["bold"])) + + +if __name__ == "__main__": + main() diff --git a/llava/constants.py b/llava/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e557b88bf7437b49bb88c754d71218ed79dfae --- /dev/null +++ b/llava/constants.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. 
+# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +DEFAULT_SOUND_TOKEN = "" +DEFAULT_SPEECH_TOKEN = "" +SENTINEL_TOKEN = "" + +MEDIA_TOKENS = { + "speech": "", + "sound": "", +} + + +""" +151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151648: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151649: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151650: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151651: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151652: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), + +""" +NUM_EXTRA_TOKENS = 10 diff --git a/llava/conversation.py b/llava/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..5b42e5bf88571231b262040d3601e63006e125e9 --- /dev/null +++ b/llava/conversation.py @@ -0,0 +1,197 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/haotian-liu/LLaVA/ + +import dataclasses +from enum import Enum, auto +from typing import List + +from llava.utils.logging import logger + + +class SeparatorStyle(Enum): + """Different separator style.""" + + AUTO = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_3 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + sep_style: SeparatorStyle = SeparatorStyle.AUTO + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0].replace("", "").strip() + messages[0] = (init_role, "\n" + init_msg) + + if self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.LLAMA_3: + ret = self.system + self.sep + for rid, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + sep = self.sep if rid < len(messages) - 1 else self.sep2 + ret += role + message + sep + else: + ret += role + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version, + ) + + +conv_auto = Conversation( + system="", + roles=("", ""), + messages=(), + sep_style=SeparatorStyle.AUTO, + sep="\n", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llava_plain = Conversation( + system="", + roles=("", ""), + messages=(), + sep_style=SeparatorStyle.PLAIN, + sep="\n", +) + +hermes_2 = Conversation( + system="<|im_start|>system\nAnswer the questions.", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", + messages=(), + version="hermes-2", +) + +# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template. +llama_3_chat = Conversation( + system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. 
" + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"), + version="llama_v3", + messages=(), + sep_style=SeparatorStyle.LLAMA_3, + sep="<|eot_id|>", + sep2="<|end_of_text|>", +) + + +default_conversation = conv_auto +conv_templates = { + "auto": conv_auto, + "hermes-2": hermes_2, + "llama_3": llama_3_chat, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "plain": conv_llava_plain, +} + + +CONVERSATION_MODE_MAPPING = { + "vila1.5-3b": "vicuna_v1", + "vila1.5-8b": "llama_3", + "vila1.5-13b": "vicuna_v1", + "vila1.5-40b": "hermes-2", + "llama-3": "llama_3", + "llama3": "llama_3", +} + + +def auto_set_conversation_mode(model_name_or_path: str) -> str: + global default_conversation + for k, v in CONVERSATION_MODE_MAPPING.items(): + if k in model_name_or_path.lower(): + logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.") + default_conversation = conv_templates[v] + return diff --git a/llava/data/__init__.py b/llava/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22d60dab57c08f90e224f3d95047c5dfbf14de6d --- /dev/null +++ b/llava/data/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .builder import * +from .dataset import * +from .datasets_mixture import * diff --git a/llava/data/base.py b/llava/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..925a500cf78667b3ab862a72ab1b564e86c2c746 --- /dev/null +++ b/llava/data/base.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +import random +from typing import Any, Dict, List + +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer + +from llava.mm_utils import dynamic_process_images_and_prompt, dynamic_s2_process_images_and_prompt, process_images +from llava.train.args import DataArguments +from llava.utils.logging import logger +from llava.utils.media import extract_media +from llava.utils.tokenizer import preprocess_conversation + +__all__ = ["BaseDataset"] + +def _process_speech(speech: List[Any], data_args: DataArguments) -> torch.Tensor: + return torch.tensor(speech) + +def _process_sound(sound: List[Any], data_args: DataArguments) -> torch.Tensor: + return torch.tensor(sound) + +def _process_sound_masks(sound_masks: List[Any], data_args: DataArguments) -> torch.Tensor: + return torch.tensor(sound_masks) + + +class BaseDataset(Dataset): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + data_args: DataArguments, + no_system_prompt: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.data_args = data_args + self.no_system_prompt = no_system_prompt + self.instances = [] + self.enable_dynamic_res = False + self.enable_dynamic_res_s2 = False + # global_batch_size: int, + self.global_batch_size = kwargs.get("global_batch_size", 1) + + # by default, dataset cls will resample on failure + self.resample_on_failure = kwargs.get("resample_on_failure", True) + + # by default, dataset cls will resample on failure + self.resample_on_failure = kwargs.get("resample_on_failure", True) + + def process(self, instance: Dict[str, Any]) -> List[Dict[str, Any]]: + raise NotImplementedError + + def __getitem__(self, index: int) -> Dict[str, Any]: + instance = self.instances[index] + + try: + # Process instance to conversation + conversation = self.process(instance) + + # Extract media from conversation + media, media_meta = extract_media(conversation, self.data_args) + + if "speech" in media: + processed_speech = _process_speech(media["speech"], self.data_args) + if "sound" in media: + processed_sound = _process_sound(media["sound"], self.data_args) + processed_sound_feature_masks = _process_sound_masks(media_meta["sound_feature_masks"], self.data_args) + processed_sound_embed_masks = _process_sound_masks(media_meta["sound_embed_masks"], self.data_args) + # Prepare "input_ids" and "labels" for training + data = preprocess_conversation(conversation, self.tokenizer, no_system_prompt=self.no_system_prompt) + + if "speech" in media: + data["speech"] = processed_speech + if "sound" in media: + data["sound"] = processed_sound + data["sound_feature_masks"] = processed_sound_feature_masks + data["sound_embed_masks"] = processed_sound_embed_masks + + except Exception as e: + if not self.resample_on_failure: + raise e + else: + logger.exception(f"Error processing instance '{instance}': '{e}'. Resampling.") + return self.__getitem__(random.randint(0, len(self.instances) - 1)) + + return data + + def __len__(self) -> int: + return len(self.instances) diff --git a/llava/data/builder.py b/llava/data/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..871d380c089e5948a58a668846ecf96faaee4afa --- /dev/null +++ b/llava/data/builder.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +import os +import os.path as osp +from itertools import chain +from typing import Any, List, Optional + +import torch +import torch.distributed as dist +from hydra.utils import instantiate +from torch.utils.data import ConcatDataset, Dataset +from transformers import PreTrainedTokenizer + +from llava.data.datasets_mixture import DATASETS_LEGACY +from llava.train.args import DataArguments, TrainingArguments +from llava.utils import io +from llava.utils.logging import logger +import time +import numpy as np +__all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"] + + +def load_dataset_yaml(name): + fname = f"{name}.yaml" if not name.endswith(".yaml") else name + + # yaml under llava/data/registry/datasets + repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname) + if osp.exists(repo_path): + return repo_path + + # # yaml under + abs_path = osp.expanduser(fname) + if osp.exists(abs_path): + return abs_path + + raise FileNotFoundError(f"Dataset '{name}' is not found in the {repo_path} or {abs_path}.") + + +def register_datasets(name: Optional[str] = None): + if name is None: + name = os.environ.get("VILA_DATASETS", "default") + logger.info(f"Registering datasets from environment: '{name}'.") + # return io.load(osp.join(osp.dirname(__file__), "registry", "datasets", f"{name}.yaml")) + dataset_meta = {} + for _name in name.split(","): + yamlpath = load_dataset_yaml(_name) + logger.info(f"Registering datasets from: '{yamlpath}'.") + meta = io.load(yamlpath) + dataset_meta.update(meta) + return dataset_meta + + +def register_mixtures(): + return io.load(os.path.join(os.path.dirname(__file__), "registry", "mixtures.yaml")) + + +DATASETS = register_datasets() +MIXTURES = register_mixtures() + + +def parse_mixture(mixture: str) -> List[str]: + names = mixture.split("+") if "+" in mixture else [mixture] + while any(name in MIXTURES for name in names): + names = list(chain(*[MIXTURES.get(name, [name]) for name in names])) + return sorted(names) + + +class SubsetDataset(Dataset): + def __init__(self, dataset: Dataset, limit: int) -> None: + super().__init__() + self.dataset = dataset + self.limit = limit + + def __len__(self) -> int: + return int(len(self.dataset) * self.limit) + + def __getitem__(self, index: int) -> Any: + return self.dataset[index % len(self.dataset)] + +class RepeatedDataset(Dataset): + def __init__(self, dataset: Dataset, times: int) -> None: + super().__init__() + self.dataset = dataset + self.times = times + + def __len__(self) -> int: + return len(self.dataset) * self.times + + def __getitem__(self, index: int) -> Any: + return self.dataset[index % len(self.dataset)] + + +def get_world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def build_dataset( + mixture: str, + data_args: DataArguments, + training_args: TrainingArguments, + tokenizer: PreTrainedTokenizer, +) -> Dataset: + logger.warning(f"Training VILA with mixture '{mixture}'.") + datasets = [] + dataset_rng = np.random.default_rng(1234) + for name in parse_mixture(mixture): + + if "*" in name: + name, times = name.split("*") + times = int(times) + else: + times = 1 + limit_dataset = False + if "#" in name: + # we limit the max length of this dataset + name, max_length_percent = name.split("#") + limit_dataset = True + if DATASETS is not None and name in DATASETS: + if name in DATASETS_LEGACY: + logger.warning(f"Dataset '{name}' exists in both new and legacy registries. 
Using the new one.") + dataset = instantiate(DATASETS[name], _partial_=True)( + tokenizer=tokenizer, + data_args=data_args, + global_batch_size=( + training_args.per_device_train_batch_size + # * torch.distributed.get_world_size() + * get_world_size() + * training_args.gradient_accumulation_steps + ), + ) + elif name in DATASETS_LEGACY: + logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.") + dataset = build_dataset_legacy( + name, + data_args=data_args, + training_args=training_args, + tokenizer=tokenizer, + ) + else: + raise ValueError(f"Dataset '{name}' is not found in the registries.") + + + if limit_dataset: + # we limit the max length of this dataset + max_length = int(float(int(max_length_percent) / 100.) * len(dataset)) + dataset = SubsetDataset(dataset, float(int(max_length_percent) / 100.)) + + if times > 1: + dataset = RepeatedDataset(dataset, times) + datasets.append(dataset) + return ConcatDataset(datasets) + + +def build_dataset_legacy( + name: str, + data_args: DataArguments, + training_args: TrainingArguments, + tokenizer: PreTrainedTokenizer, +) -> Dataset: + from llava.data.dataset import ( + LazySupervisedDataset, + LazyWDSDataset, + ) + + dataset = DATASETS_LEGACY[name] + dataset_type = dataset.dataset_type + if dataset_type == "torch": + dataset_cls = LazySupervisedDataset + elif dataset_type == "wds": + dataset_cls = LazyWDSDataset + else: + raise NotImplementedError(f"{dataset_type} is not supported.") + + data_args.meta_path = getattr(dataset, "meta_path", None) + data_args.caption_choice = getattr(dataset, "caption_choice", None) + data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None) + data_args.start_idx = getattr(dataset, "start_idx", None) + data_args.end_idx = getattr(dataset, "end_idx", None) + + return dataset_cls( + tokenizer=tokenizer, + data_path=dataset.data_path, + image_folder=getattr(dataset, "image_path"), + data_args=data_args, + training_args=training_args, + ) diff --git a/llava/data/collate.py b/llava/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..4795b77a246acce9294f07b189f0a06725b71f2c --- /dev/null +++ b/llava/data/collate.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from dataclasses import dataclass +from typing import Any, Dict, Sequence + +import torch +from transformers import PreTrainedTokenizer + +from llava.constants import IGNORE_INDEX +from llava.utils.logging import logger + +__all__ = ["DataCollator"] + + +@dataclass +class DataCollator: + tokenizer: PreTrainedTokenizer + + def __init__(self, tokenizer: PreTrainedTokenizer): + super().__init__() + self.tokenizer = tokenizer + + def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + # Gather everything from the batch + input_ids, labels, media, block_sizes = [], [], {name: [] for name in self.tokenizer.media_tokens}, [] + + media_meta = {} + + media_meta["sound_feature_masks"] = [] + media_meta["sound_embed_masks"] = [] + media_meta["frame_times"] = [] + for instance in instances: + if isinstance(instance["input_ids"], torch.Tensor): + input_ids.append(instance["input_ids"]) + labels.append(instance["labels"]) + for name in media: + objs = instance.get(name) + objs = objs if objs is not None else [] + media[name].append([obj for obj in objs]) + if instance.get("sound") is not None: + for name_k in media_meta: + if "sound" in name_k: + objs = instance.get(name_k) + media_meta[name_k].append([obj for obj in objs]) + if instance.get("video") is not None or instance.get("image") is not None: + for name_k in media_meta: + if "frame" in name_k: + objs = instance.get(name_k) + media_meta[name_k].append([obj for obj in objs]) + if "block_sizes" in instance: + block_sizes.append(instance["block_sizes"]) + else: + block_sizes.append( + [None for _ in range(len(instance.get("image")))] if instance.get("image") is not None else [] + ) + else: + input_ids.extend(instance["input_ids"]) + labels.extend(instance["labels"]) + for name in media: + objs = instance.get(name) + objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))] + media[name].extend(objs) + if instance.get("sound") is not None: + for name_k in media_meta: + if "sound" in name_k: + objs = instance.get(name_k) + media_meta[name_k].extend(objs) + if instance.get("video") is not None or instance.get("image") is not None: + for name_k in media_meta: + if "frame" in name_k: + objs = instance.get(name_k) + media_meta[name_k].append([obj for obj in objs]) + if "block_sizes" in instance: + block_sizes.extend(instance["block_sizes"]) + else: + block_sizes.extend( + [[None for _ in range(len(objs))] for objs in instance.get("image")] + if instance.get("image") is not None + else [[] for _ in range(len(instance["input_ids"]))] + ) + + batch_size = len(input_ids) + + + # Check if the number of media objects (or the number of block sizes) matches the number of media tokens + for name in media: + for k in range(batch_size): + if name == "image" and not all([_ is None for _ in block_sizes[k]]): + actual = len(block_sizes[k]) + else: + actual = len(media[name][k]) + expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item() + if actual != expected: + raise ValueError( + f"Number mismatch between {name} objects and {name} tokens. " + f"There are {expected} {name} tokens but {actual} {name} objects." 
+ ) + + # Batchify the inputs + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, + batch_first=True, + padding_value=IGNORE_INDEX, + ) + input_ids = input_ids[:, : self.tokenizer.model_max_length] + labels = labels[:, : self.tokenizer.model_max_length] + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) + + # Truncate media objects if necessary + for name in media: + objects = [] + for k in range(batch_size): + if name == "image" and not all([_ is None for _ in block_sizes[k]]): + actual = len(media[name][k]) + num_large_scale_blocks = sum([x * y for x, y in block_sizes[k]]) + num_small_scale_blocks = actual - num_large_scale_blocks + num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k]) + expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item() + expected = ( + sum([x * y for x, y in block_sizes[k][:expected_full_image]]) + + num_small_scale_blocks_each_img * expected_full_image + ) + if actual > expected: + logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}") + media[name][k] = media[name][k][:expected] + objects.extend(media[name][k]) + block_sizes[k] = block_sizes[k][:expected_full_image] + else: + actual = len(media[name][k]) + expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item() + if actual > expected: + logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}") + media[name][k] = media[name][k][:expected] + objects.extend(media[name][k]) + if name == "image": + block_sizes[k] = block_sizes[k][:expected] + media[name] = objects + + for name in media_meta: + objects = [] + for k in range(batch_size): + try: + objects.extend(media_meta[name][k]) + except: + continue + media_meta[name] = objects + + # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...] to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...] + block_sizes = sum(block_sizes, []) + return { + "input_ids": input_ids, + "media": media, + "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}}, + "labels": labels, + "attention_mask": attention_mask, + "media_meta": media_meta, + } diff --git a/llava/data/dataset.py b/llava/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e549c6c66fe7db2ce2173f078cdf2d70f291cfd8 --- /dev/null +++ b/llava/data/dataset.py @@ -0,0 +1,1635 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import base64 +import copy +import io +import json +import os +import os.path as osp +import random +import time +import warnings +from dataclasses import dataclass +from typing import Dict, Sequence +import math +import numpy as np +import PIL +import torch +import transformers +from PIL import Image, ImageFile +from torch.utils.data import Dataset, default_collate +from transformers import PreTrainedTokenizer +from transformers import AutoFeatureExtractor +import kaldiio +import llava.data.datasets_mixture as datasets_mixture +from llava import conversation as conversation_lib +from llava.constants import DEFAULT_SOUND_TOKEN,DEFAULT_SPEECH_TOKEN, IGNORE_INDEX +from llava.data.collate import DataCollator +from llava.mm_utils import ( + load_audio, + get_num_windows, + tokenizer_image_token, +) +from torchvision import transforms +from llava.train.args import DataArguments, TrainingArguments +from llava.train.sequence_parallel import ( + extract_local_from_list, + extract_local_input_ids, + extract_local_position_ids, + get_pg_manager, +) +from llava.utils.tokenizer import preprocess_conversation +# import torchaudio +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler +import soundfile as sf +from librosa import resample as librosa_resample +import whisper +ImageFile.LOAD_TRUNCATED_IMAGES = True +PIL.Image.MAX_IMAGE_PIXELS = 1000000000 + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1., a_max=1.) + return (x * 32767.).astype(np.int16) + + + +def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + concat_values = "".join([sentence["value"] for sentence in source]) + for sid, sentence in enumerate(source): + # In multimodal conversations, we automatically prepend '' at the start of the first sentence if it doesn't already contain one. 
+ + if DEFAULT_SOUND_TOKEN in sentence["value"]: + sentence["value"] = sentence["value"].replace(DEFAULT_SOUND_TOKEN, f"{DEFAULT_SOUND_TOKEN}\n") + sentence["value"] = sentence["value"].replace(f"{DEFAULT_SOUND_TOKEN}\n\n", f"{DEFAULT_SOUND_TOKEN}\n") + if DEFAULT_SPEECH_TOKEN in sentence["value"]: + sentence["value"] = sentence["value"].replace(DEFAULT_SPEECH_TOKEN, f"{DEFAULT_SPEECH_TOKEN}\n") + sentence["value"] = sentence["value"].replace(f"{DEFAULT_SPEECH_TOKEN}\n\n", f"{DEFAULT_SPEECH_TOKEN}\n") + return sources + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_IMAGE_TOKEN in source[0]["value"] + source[0]["value"] = DEFAULT_IMAGE_TOKEN + conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + no_system_prompt: bool = False, +) -> Dict: + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + return default_collate( + [ + preprocess_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt) + for conversation in sources + ] + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is originally implemented by the LLaVA team and modified by + Ji Lin and Haotian Tang. + """ + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + ): + super().__init__() + try: + with open(data_path) as fp: + list_data_dict = json.load(fp) + except: + with open(data_path) as fp: + list_data_dict = [json.loads(q) for q in fp] + + # rank0_print("Formatting inputs...Skip in lazy mode") + print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + self.image_folder = image_folder + self.wav_processor = AutoFeatureExtractor.from_pretrained('Qwen/Qwen2-Audio-7B') + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if "image" in sample else 0 + length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + if 'duration' in sample.keys(): + duration = sample["duration"] + else: + duration = 10. 
+ try: + cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) + int(math.ceil(duration * 25)) + cur_len = cur_len if "sound" in sample else -cur_len + length_list.append(cur_len) + except: + try: + cur_len = 0 + int(math.ceil(duration * 25)) + cur_len = cur_len if "sound" in sample else -cur_len + length_list.append(cur_len) + except: + cur_len = 0 + int(math.ceil(10. * 25)) + cur_len = cur_len if "sound" in sample else -cur_len + length_list.append(cur_len) + return length_list + + @staticmethod + def _load_sound(sound_file, wav_processor, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=3, audio_start = 0.0): + if sound_file is None: + return None + window_length = int(window_length * sample_rate) + window_overlap = int(window_overlap * sample_rate) + max_num_window = int(max_num_window) + duration = max_num_window * (window_length - window_overlap) + window_overlap + + sound_outputs = [] + audio_feature_masks = [] + audio_embed_masks = [] + + try: + sound_filename = str.split(sound_file, '/')[-1] + if '.ark' in sound_filename: + sound = kaldiio.load_mat(sound_file) + audio_data = sound[1] + audio_data=audio_data.astype(np.float16) + else: + audio_data = load_audio(sound_file, sample_rate, duration, audio_start) # already cuts to max duration + T = len(audio_data) + audio_data = audio_data.reshape(1, -1) + num_windows, full_length = get_num_windows(T, sample_rate, max_num_window) + + audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() + + for i in range(num_windows): + audio_embed_mask = torch.zeros(750) + start = i * (window_length - window_overlap) + audio_data_tensor_this = audio_data_tensor[:, start:start+window_length] + orig_length = audio_data_tensor_this.shape[1] + audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt") #.squeeze(0) text="dummy", audios=audio_data_tensor_this, return_tensors="pt") # + sound_outputs.append(audio_data_tensor_this["input_features"]) + # calculate the mask for the input melspec to Whisper + melspec_frames_this_window = int(math.ceil(orig_length / 160)) + feature_attention_mask = torch.zeros(3000, dtype=torch.int32) + feature_attention_mask[:melspec_frames_this_window] = 1 + audio_feature_masks.append(feature_attention_mask.unsqueeze(0)) + # calculate the mask for the output embedding for use in AF2 + conv_lengths = (melspec_frames_this_window - 1) // 2 + 1 + output_embedding_lengths = (conv_lengths - 2) // 2 + 1 + audio_embed_mask[:output_embedding_lengths] = 1 + audio_embed_masks.append(audio_embed_mask) + except: + print('error loading file', sound_file) + sound_outputs.append(torch.zeros(1,128,3000)) + audio_feature_masks.append(torch.zeros(1,3000, dtype=torch.int32)) + audio_embed_masks.append(torch.zeros(750)) + + return torch.stack(sound_outputs, dim=0), torch.stack(audio_feature_masks, dim=0), torch.stack(audio_embed_masks, dim=0) + + @staticmethod + def _load_speech(speech_path,sample_rate=16000): + if speech_path is None: + return None + + speech_outputs = [] + try: + speech = whisper.load_audio(speech_path) + speech = whisper.pad_or_trim(speech) + mel = whisper.log_mel_spectrogram(speech) + speech_outputs.append(mel.unsqueeze(0)) + except: + speech_outputs.append(torch.zeros(1,80,3000)) + return torch.stack(speech_outputs, dim=0) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert 
len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + + import re + if "sound" in self.list_data_dict[i]: + # chat data loading + if isinstance(self.list_data_dict[i]["sound"],list): + sound_files = self.list_data_dict[i]["sound"] + conversations_raw = self.list_data_dict[i]["conversations"] + + # Step 1: Extract tags in order of appearance + sound_tag_pattern = re.compile(r"") + ordered_sound_tags = [] + + for turn in conversations_raw: + tags = sound_tag_pattern.findall(turn["value"]) + ordered_sound_tags.extend([f"" for tag in tags]) + + # Step 2: Load sound tensors in the order of tags + sound_tensor = [] + audio_feature_masks = [] + audio_embed_masks = [] + sound_token_map = {} + + for tag in ordered_sound_tags: + idx = int(tag.split('-')[1][:-1]) + if tag not in sound_token_map: + this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames) + this_sound_tensor = this_sound_tensor.squeeze(1) # (windows x 750 x 2048) + sound_token_map[tag] = ("\n" * this_sound_tensor.shape[0]).rstrip() + sound_tensor.append(this_sound_tensor) + audio_feature_masks.append(af_mask) + audio_embed_masks.append(ae_mask) + else: + # If already loaded, still append to match sequence + this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames) + this_sound_tensor = this_sound_tensor.squeeze(1) + sound_tensor.append(this_sound_tensor) + audio_feature_masks.append(af_mask) + audio_embed_masks.append(ae_mask) + + + # Process conversations and inject sound markers + conversation = [] + for turn in conversations_raw: + role = turn["from"] + value = turn["value"] + + # Replace any tag with corresponding repeated \n + for tag, sound_token in sound_token_map.items(): + value = value.replace(tag, sound_token) + + conversation.append({ + "from": role, + "value": value.rstrip() + }) + + sources = [conversation] + sound_tensor = torch.cat(sound_tensor, dim=0) + audio_feature_masks = torch.cat(audio_feature_masks, dim=0) + audio_embed_masks = torch.cat(audio_embed_masks, dim=0) + else: + sound_file = self.list_data_dict[i]["sound"] + question = str(self.list_data_dict[i]["conversations"][0]["value"].rstrip()) + answer = str(self.list_data_dict[i]["conversations"][1]["value"]).rstrip() + question = question.replace("\n", "").replace("\n", "").replace("", "") + question = question.replace("\n", "").replace("\n", "").replace("", "") + question = question.replace("\n", "").replace("\n", "").replace("", "") + question = question.replace("\n", "").replace("\n", "").replace("", "") + sound_tensor, audio_feature_masks, audio_embed_masks = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames) + sound_tensor=sound_tensor.squeeze(1) # squeeze the irrelevant dimension which was caused due to processor getting 1 batch for processing --> (windows x 750 x 2048) + question = "\n" * sound_tensor.shape[0] + question + conversation = [ + {"from": "human", "value": question}, + {"from": "gpt", "value": answer}, + ] + + sources = [conversation] + data_dict = preprocess( + sources, + self.tokenizer, + has_image=( + "sound" in self.list_data_dict[i] + ), + ) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + if "sound" in self.list_data_dict[i]: + data_dict["sound"] = sound_tensor + data_dict["sound_feature_masks"] = audio_feature_masks + data_dict["sound_embed_masks"] = audio_embed_masks 
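+ # Tensors attached above, as produced by _load_sound plus the squeeze(1):
+ # data_dict["sound"] is (num_windows, 128, 3000) log-mel features from the
+ # Qwen2-Audio feature extractor, "sound_feature_masks" is (num_windows, 1, 3000)
+ # over mel frames, and "sound_embed_masks" is (num_windows, 750) over the audio
+ # encoder's output embeddings.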
+ if "speech" in self.list_data_dict[i]: + data_dict["speech"] = speech_tensor + + return data_dict + + +class LazyMMC4Dataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ji Lin and Haotian Tang.""" + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + image_following_text_only=False, + text_only=False, + ): + super().__init__() + + import pickle + + n_samples = [] + # actually shards and stats info + n_shards = len(os.listdir(data_path)) // 2 + # n_shards = 100 + count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards] + n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list] + + print("total MMC4 samples", sum(n_samples)) # 10,881,869 + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"]) + world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"]) + shared_size = n_shards // world_size + + gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)] + self.n_samples = min(gpu_samples) * world_size # total size + self.idx_offset = rank * min(gpu_samples) + shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size + print(f" * loading data from shard {shard_start}-{shard_end}") + + shard_names = [d.replace(".count", ".pkl") for d in count_info_list] + shard_names = shard_names[shard_start:shard_end] + + full_data_list = [] + # now load data + for shard_name in shard_names: + # load shard + with open(os.path.join(data_path, shard_name), "rb") as f: + data_list = pickle.load(f) + + full_data_list.extend(data_list) + + print(f"* loaded totally {len(full_data_list)} samples") + + self.data_list = full_data_list + + self.tokenizer = tokenizer + self.data_args = data_args + self.image_folder = image_folder + + self.image_following_text_only = image_following_text_only + self.text_only = text_only + + def __len__(self): + # return len(self.data_list) + return self.n_samples + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for info in self.data_list: + num_images = min(6, len(info["image_info"])) + sentences = [info["text_list"][x["matched_text_index"]] for x in info["image_info"][:num_images]] + # The unit of cur_len is "words". We assume 1 word = 2 tokens. + cur_len = num_images * self.num_image_tokens // 2 + sum([len(x) for x in sentences]) + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + info = self.data_list[i - self.idx_offset] + + sentences = info["text_list"] + # kentang-mit@: remove existing tokens in the sentences + for ix in range(len(sentences)): + # if this is an html tag, we still preserve its semantic meaning + sentences[ix] = sentences[ix].replace("", "") + sim_matrix = info["similarity_matrix"] # we do not use this... 
+ + # convert images from base64 to PIL and filter based on image-text similarity + images, sentence_ixs = [], [] + if not self.text_only: + for sample_image, sim_vec in zip(info["image_info"], sim_matrix): + image_base64 = sample_image["image_base64"] + rawbytes = base64.b64decode(image_base64) + + sim_ix = sample_image["matched_text_index"] + # sim_ix = np.argmax(sim_vec) + # sim_score = sim_vec[sim_ix] + + # filter to images >= 5KB + # if len(rawbytes) // 1000 <= 5: + # continue + # if sim_score < 0.24: + # continue + image = Image.open(io.BytesIO(rawbytes)).convert("RGB") + + images.append(image) + sentence_ixs.append(sim_ix) + + # constrain max num 6 images + max_num_images = 6 + if len(images) > max_num_images: + images = images[:max_num_images] + sentence_ixs = sentence_ixs[:max_num_images] + + # reorder images according to text insertion + images = [images[iii] for iii in np.argsort(sentence_ixs)] + + # preprocess and tokenize text + for ix in sentence_ixs: + sentences[ix] = f"\n{sentences[ix]}" + + if self.image_following_text_only: + # use pad tokens to divide sentence pieces + text = self.tokenizer.pad_token.join(sentences) + else: + text = " ".join(sentences) + # whitespace cleanup + text = text.replace(" ", "").replace(" ", "") + text = f"{text}{self.tokenizer.eos_token}" # add eos token + + if len(images) > 0: + if self.data_args.image_aspect_ratio == "dynamic_s2": + images, block_sizes = dynamic_s2_process_images_and_prompt( + images, text, self.data_args, self.image_folder + ) + elif self.data_args.image_aspect_ratio == "dynamic": + images, text = dynamic_process_images_and_prompt( + images, text, self.data_args, self.image_folder, max_tiles=6 + ) + else: + images = torch.stack([process_image(image, self.data_args, self.image_folder) for image in images]) + + # the same size for all images, so we concat + # cur_token_len = ( + # images[0].shape[-2] // self.multimodal_cfg["patch_size"] + # ) * (images[0].shape[-1] // self.multimodal_cfg["patch_size"]) + # cur_token_len += self.multimodal_cfg["n_extra_patch"] + else: + images = None + # cur_token_len = 0 + + input_ids = tokenizer_image_token( + text, + self.tokenizer, + return_tensors="pt", + ) + + image_token_id = self.tokenizer.media_token_ids["image"] + + # now check the case where the last token is image patch token + if input_ids[-1] == image_token_id: # need to remove one last image + last_non_im_patch_indices = torch.where(input_ids != image_token_id)[0][-1] + 1 + input_ids = input_ids[:last_non_im_patch_indices] + + n_im_patch = (input_ids == image_token_id).sum().item() + + if self.data_args.image_aspect_ratio != "dynamic_s2": + images = images[:n_im_patch] + assert len(images) == n_im_patch, print(text, input_ids) + assert len(input_ids.shape) == 1, "Unexpected shape of 'input_ids' from MMC4." 
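+ # The rest of __getitem__ prepends BOS when the tokenizer defines one and it is
+ # missing, clones input_ids into labels, and in image_following_text_only mode
+ # masks with IGNORE_INDEX every position that does not follow an image token,
+ # as well as all pad tokens.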
+ input_ids = ( + torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids]) + if self.tokenizer.bos_token_id is not None and input_ids[0] != self.tokenizer.bos_token_id + else input_ids + ) + targets = input_ids.clone() + + if self.image_following_text_only: # keep only text after leading image token + # remove loss for any token before the first token + label_idx = 0 + while label_idx < targets.shape[-1] and targets[label_idx] != image_token_id: + targets[label_idx] = IGNORE_INDEX + label_idx += 1 + + pad_token = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0] + + pad_token_idxs = torch.where(targets == pad_token)[0] + for pad_token_idx in pad_token_idxs: + token_idx = pad_token_idx + 1 + while token_idx < targets.shape[-1] and targets[token_idx] != image_token_id: + targets[token_idx] = IGNORE_INDEX + token_idx += 1 + # do not train on padding tokens + targets[targets == pad_token] = IGNORE_INDEX + + # mask image tokens is unnecessary for llava-1.5 + # targets[targets == IMAGE_TOKEN_INDEX] = IGNORE_INDEX + # print(input_ids.shape) + + data_dict = dict(input_ids=input_ids, labels=targets, image=images) + if self.data_args.image_aspect_ratio == "dynamic_s2": + data_dict["block_sizes"] = block_sizes + + return data_dict + + +class LazyCoyoDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ji Lin and Haotian Tang.""" + + num_image_tokens = 576 + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + # kentang-mit@: balance the total number of tokens for Coyo and MMC4. + n_samples_per_idx=4, + ): + super().__init__() + + import pickle + + n_samples = [] + # actually shards and stats info + n_shards = len(os.listdir(data_path)) // 2 + # n_shards = 100 + count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards] + n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list] + + print("total COYO samples", sum(n_samples)) + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"]) + world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"]) + shared_size = n_shards // world_size + + gpu_samples = [ + sum(n_samples[i * shared_size : (i + 1) * shared_size]) // n_samples_per_idx for i in range(world_size) + ] + self.n_samples = min(gpu_samples) * world_size # total size + self.idx_offset = rank * min(gpu_samples) + + shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size + print(f" * loading data from shard {shard_start}-{shard_end}") + + shard_names = [d.replace(".count", ".pkl") for d in count_info_list] + shard_names = shard_names[shard_start:shard_end] + + full_data_list = [] + # now load data + for shard_name in shard_names: + # load shard + with open(os.path.join(data_path, shard_name), "rb") as f: + shard_data = pickle.load(f) + random.seed(42) + if "mmc4" in data_path: + random.shuffle(shard_data) # shuffle for MMC4cap only + full_data_list.extend(shard_data) + + print(f"* loaded totally {len(full_data_list)} samples") + + # now pack the samples into groups + n_groups = 
len(full_data_list) // n_samples_per_idx + full_data_list = [ + full_data_list[i : i + n_samples_per_idx] for i in range(0, len(full_data_list), n_samples_per_idx) + ] + if len(full_data_list[-1]) < n_samples_per_idx: + full_data_list = full_data_list[:-1] + assert len(full_data_list) == n_groups + print(f"split into {n_groups} groups") + + self.data_list = full_data_list + + self.tokenizer = tokenizer + self.data_args = data_args + self.image_folder = image_folder + + def __len__(self): + # return len(self.data_list) + return self.n_samples + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for samples in self.data_list: + cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples]) + # The unit of cur_len is "words". We assume 1 word = 2 tokens. + cur_len = cur_len + len(samples) * self.num_image_tokens // 2 + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + CONCAT_SAMPLES = False + info_list = self.data_list[i - self.idx_offset] + + text_list = [] + image_list = [] + + for sample in info_list: + caption_key = ( + "text" if "text" in sample else "caption" + ) # kentang-mit@: remove existing tokens in the sentences + # kentang-mit@: remove existing token. + # if this is an html tag, we still preserve its semantic meaning + sample[caption_key] = sample[caption_key].replace("", "") + text_list.append(DEFAULT_IMAGE_TOKEN + "\n" + sample[caption_key] + self.tokenizer.eos_token) + if "image" in sample: + image_base64 = sample["image"] + rawbytes = base64.b64decode(image_base64) + else: + rawbytes = sample["rawbytes"] + image = Image.open(io.BytesIO(rawbytes)).convert("RGB") + image_list.append(image) + + image_list = torch.stack([process_image(image, self.data_args, self.image_folder) for image in image_list]) + + if CONCAT_SAMPLES: + # into capcap... + text_list = "".join(text_list) + + input_ids = self.tokenizer( + text_list, + return_tensors="pt", + padding="longest", + max_length=self.tokenizer.model_max_length, + truncation=True, + ).input_ids # 4, seq_len + + input_ids = input_ids[0] + + else: + input_ids = [ + tokenizer_image_token( + prompt, + self.tokenizer, + return_tensors="pt", + ) + for prompt in text_list + ] + # print([x.shape[0] for x in input_ids], [len(x.split()) for x in text_list], [len(re.findall(r"]*>", x)) for x in text_list]) + + # input_ids = torch.nn.utils.rnn.pad_sequence( + # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + # ) + + targets = copy.deepcopy(input_ids) + for i in range(len(targets)): + targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets, image=image_list) + + +class LazyWDSDataset(Dataset): + """Dataset for supervised fine-tuning. 
+ This class is implemented by Ji Lin and Ligeng Zhu.""" + + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + image_folder: str, + training_args: TrainingArguments, + ): + super().__init__() + n_samples = [] + n_shards = len(os.listdir(data_path)) // 3 + for shard in range(n_shards): + with open(os.path.join(data_path, f"{shard:05d}_stats.json")) as f: + info = json.load(f) + n_samples.append(info["successes"]) + + # print(f"[DEBUG] {data_path} total samples", sum(n_samples)) # 10,881,869 + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"]) + world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"]) + shared_size = n_shards // world_size + print("rank", rank, "world_size", world_size, "shared_size", shared_size) + gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)] + self.n_samples = min(gpu_samples) * world_size # total size + self.idx_offset = rank * min(gpu_samples) + shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size + print(f" * loading data from shard {shard_start}-{shard_end}") + + tar_list = [f"{shard_idx:05d}.tar" for shard_idx in range(shard_start, shard_end)] + + self.data_list = [] + t1 = time.time() + for tar in tar_list: + tmp_path = f"/tmp/ccs{tar}" + tar_path = os.path.join(data_path, tar) + + if PROCESS_GROUP_MANAGER is not None: + dist.barrier() + if PROCESS_GROUP_MANAGER.sp_rank == 0: + os.makedirs(tmp_path, exist_ok=True) + os.system(f"tar -xkf {tar_path} -C {tmp_path}") + dist.barrier() + else: + os.makedirs(tmp_path, exist_ok=True) + os.system(f"tar -xkf {tar_path} -C {tmp_path}") + + txt_list = [f for f in os.listdir(tmp_path) if f.endswith(".txt")] + + for txt in txt_list: + caption = open(os.path.join(tmp_path, txt)).read().strip() + image_path = os.path.join(tmp_path, txt.split(".")[0] + ".jpg") + self.data_list.append({"caption": caption, "image": image_path}) + t2 = time.time() + print(f"Loading done. 
Total time: {t2 - t1:.2f} seconds") + + self.tokenizer = tokenizer + self.data_args = data_args + self.image_folder = image_folder + + def __len__(self): + return self.n_samples + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + + # print("i", i, "idx_offset", self.idx_offset, "len", len(self.data_list)) + info = self.data_list[i - self.idx_offset] + caption, image_path = info["caption"], info["image"] + + rand_prompt = "\n" + sources = [ + { + "image": image_path, + "conversations": [ + {"from": "human", "value": rand_prompt}, + {"from": "gpt", "value": caption}, + ], + } + ] + + # one example of sources + # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}] + if "image" in sources[0]: + image = process_image(sources[0]["image"], self.data_args, self.image_folder) + image = torch.unsqueeze(image, dim=0) + # now random pick some context samples for training + if hasattr(self.data_args, "num_shots"): + if self.data_args.num_shots > 0: + raise NotImplementedError + else: + raise NotImplementedError + + data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True) + + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + # image exist in the data + if image is not None: + data_dict["image"] = image + else: + raise NotImplementedError + + return data_dict + + +class LazyCCSWebDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ligeng Zhu.""" + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + ): + super().__init__() + t1 = time.time() + + from llava.data.simple_vila_webdataset import VILAWebDataset + + print("[DEBUG] ", osp.abspath(data_path)) + self.dataset = VILAWebDataset(data_path=osp.abspath(data_path)) + + t2 = time.time() + print(f"Loading done. 
Total time: {t2 - t1:.2f} seconds") + + self.tokenizer = tokenizer + self.data_args = data_args + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + # info = self.data_list[i - self.idx_offset] + # caption, image_path = info["caption"], info["image"] + info = self.dataset[i] + if ".jpg" in info: + caption, image_path = info[".txt"], info[".jpg"] + elif ".png" in info: + caption, image_path = info[".txt"], info[".png"] + elif ".webp" in info: + caption, image_path = info[".txt"], info[".webp"] + elif ".bmp" in info: + caption, image_path = info[".txt"], info[".bmp"] + elif ".tiff" in info: + caption, image_path = info[".txt"], info[".tiff"] + else: + print(info.keys()) + print(info) + raise KeyError + + caption = caption.replace("", "") + if isinstance(image_path, io.BytesIO): + image_path = Image.open(image_path).convert("RGB") + + if not isinstance(image_path, PIL.Image.Image): + print(image_path) + print(info.keys()) + print(type(image_path)) + raise NotImplementedError + + rand_prompt = "\n" + sources = [ + { + "image": image_path, + "conversations": [ + {"from": "human", "value": rand_prompt}, + {"from": "gpt", "value": caption}, + ], + } + ] + + # one example of sources + # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}] + if "image" in sources[0]: + image = process_image(sources[0]["image"], self.data_args, image_folder=None) + image = torch.unsqueeze(image, dim=0) + # now random pick some context samples for training + if hasattr(self.data_args, "num_shots"): + if self.data_args.num_shots > 0: + raise NotImplementedError + else: + raise NotImplementedError + + data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True) + + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + # image exist in the data + if image is not None: + data_dict["image"] = image + else: + raise NotImplementedError + + return data_dict + + +from functools import lru_cache + + +@lru_cache(maxsize=16) +def lru_json_load(fpath): + with open(fpath) as fp: + return json.load(fp) + + +class LazyCoyoWebDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ligeng Zhu.""" + + num_image_tokens = 576 + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + # kentang-mit@: balance the total number of tokens for Coyo and MMC4. + n_samples_per_idx=4, + ): + super().__init__() + + from llava.data.simple_vila_webdataset import VILAWebDataset + + print("[DEBUG] ", osp.abspath(data_path)) + self.dataset = VILAWebDataset(data_path=osp.abspath(data_path), meta_path=data_args.meta_path) + + if data_args.start_idx >= 0 and data_args.end_idx >= 0: + # Ligeng: support slicing for ablate different subsets. 
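+ # start_idx / end_idx are fractions in [0, 1]: the subset taken below keeps
+ # range(int(total * start_idx), int(total * end_idx)) of the webdataset samples.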
+ total = len(self.dataset) + start_idx = int(total * data_args.start_idx) + end_idx = int(total * data_args.end_idx) + print(f"loading subset from {start_idx} to {end_idx}, total {total}") + self.dataset = torch.utils.data.Subset(self.dataset, range(start_idx, end_idx)) + + # For caption choice, + # if None: use original caption + # if a folder path: use specified caption to override original one (choice1) + # if a folder path: use specified caption and concat with original one (choice2) + self.caption_choice = None + self.caption_choice_2 = None + self.data_path = data_path + + if data_args.caption_choice is not None: + self.caption_choice = data_args.caption_choice + print("[recap] Override coyo caption using ", self.caption_choice) + + if data_args.caption_choice_2 is not None: + self.caption_choice_2 = data_args.caption_choice_2 + print("[recapv2] Override coyo caption using ", self.caption_choice_2) + + print("total samples", len(self.dataset)) + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = ( + training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2 + ) # int(os.environ["RANK"]) + world_size = ( + training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32 + ) # int(os.environ["WORLD_SIZE"]) + print( + "rank", + rank, + "world_size", + world_size, + ) + + self.n_samples_per_idx = n_samples_per_idx + # self.n_samples = len(self.dataset) // n_samples_per_idx + self.tokenizer = tokenizer + self.data_args = data_args + + def __len__(self): + return len(self.dataset) // self.n_samples_per_idx + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for samples in self.data_list: + cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples]) + # The unit of cur_len is "words". We assume 1 word = 2 tokens. 
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2 + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + CONCAT_SAMPLES = False + # info_list = self.dataset[i - self.idx_offset] + + begin_idx, end_idx = ( + i * self.n_samples_per_idx, + (i + 1) * self.n_samples_per_idx, + ) + end_idx = min(end_idx, len(self.dataset)) + + text_list = [] + image_list = [] + + for idx in range(begin_idx, end_idx): + info = self.dataset[idx] + if ".jpg" in info: + caption, image_path = info[".txt"], info[".jpg"] + elif ".png" in info: + caption, image_path = info[".txt"], info[".png"] + elif ".webp" in info: + caption, image_path = info[".txt"], info[".webp"] + elif ".bmp" in info: + caption, image_path = info[".txt"], info[".bmp"] + elif ".tiff" in info: + caption, image_path = info[".txt"], info[".tiff"] + else: + print(info.keys()) + print(info) + raise KeyError + + if self.caption_choice is not None: + # load new captions + shard = info["__shard__"] + url = info[".json"]["url"] + tar_name = osp.relpath(osp.realpath(shard), osp.realpath(self.data_path)) + # tar_name = osp.dirname(shard) + shard_json_path = osp.join(self.caption_choice, tar_name + ".json") + try: + shard_json = lru_json_load(shard_json_path) + try: + caption = shard_json[url]["output"] + except KeyError: + print(f"{url} not in caption. fallback to original caption temporarially") + except: + print(f"shard_json_path {shard_json_path} not found. fallback to original caption temporarially") + caption = caption.replace("", "") + text_list.append(DEFAULT_IMAGE_TOKEN + caption + self.tokenizer.eos_token) + + if isinstance(image_path, io.BytesIO): + image_path = Image.open(image_path).convert("RGB") + + if not isinstance(image_path, PIL.Image.Image): + print(image_path) + print(info.keys()) + print(type(image_path)) + raise NotImplementedError + + image_list.append(image_path) + + # image_list = torch.stack([process_image(image, self.data_args, image_folder=None) for image in image_list]) + # NOTE(fix by ligeng) + # now image_list should return a list of image tensor where each has a dimension of (1, c, h, w) + image_list = [process_image(image, self.data_args, image_folder=None).unsqueeze(0) for image in image_list] + + if CONCAT_SAMPLES: + # into capcap... 
+ text_list = "".join(text_list) + + input_ids = self.tokenizer( + text_list, + return_tensors="pt", + padding="longest", + max_length=self.tokenizer.model_max_length, + truncation=True, + ).input_ids # 4, seq_len + + input_ids = input_ids[0] + else: + input_ids = [ + tokenizer_image_token( + prompt, + self.tokenizer, + return_tensors="pt", + ) + for prompt in text_list + ] + input_ids = [ + ( + torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids_i]) + if input_ids_i[0] != self.tokenizer.bos_token_id + else input_ids_i + ) + for input_ids_i in input_ids + ] + + targets = copy.deepcopy(input_ids) + for i in range(len(targets)): + targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets, image=image_list) + + +class LazyVideoWebDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + # cache_path: str, + # n_samples_per_idx=4, + ): + super().__init__() + + # from llava.data.simple_video_dataset import SimpleVideoDataset + + from llava.data.simple_vila_webdataset import VILAWebDataset + + print("[DEBUG] ", osp.abspath(data_path)) + self.dataset = VILAWebDataset( + data_path=osp.abspath(data_path), + meta_path=f"{osp.abspath(data_path)}/wids-meta.json", + # cache_dir=cache_path, + ) + + # None: use original caption + # Folder path: use original caption + self.caption_choice = None + self.data_path = data_path + + if data_args.caption_choice is not None: + self.caption_choice = data_args.caption_choice + print("[recap] Override LazyVideo caption using ", self.caption_choice) + + print("total samples", len(self.dataset)) + # InternVid: TODO + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = ( + training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2 + ) # int(os.environ["RANK"]) + world_size = ( + training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32 + ) # int(os.environ["WORLD_SIZE"]) + print( + "rank", + rank, + "world_size", + world_size, + ) + self.rank = rank + # rank = int(os.environ["RANK"]) if "RANK" in os.environ else 2 + # world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 32 + + self.tokenizer = tokenizer + self.data_args = data_args + + self.missing_uids = set() + + def __len__(self): + return len(self.dataset) + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for samples in self.data_list: + cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples]) + # The unit of cur_len is "words". We assume 1 word = 2 tokens. 
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2 + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + ADD_TEXT_PROMPT = False + num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8 + loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0 + + info = self.dataset[i] + + caption = "" + # print(info) + if ".mp4" in info: + caption, video_path = info[".txt"], info[".mp4"] + else: + video_path = None + caption = "Empty video." + + images, frames_loaded, _ = LazySupervisedDataset._load_video( + video_path, num_video_frames, loader_fps, self.data_args + ) + + if frames_loaded == 0: + caption = "Empty video." + + if self.caption_choice is not None: + shard = info["__shard__"] + uuid = osp.join(info["__shard__"], info["__key__"]) + url = info["__key__"] + tar_name = osp.basename(info["__shard__"]) + + try: + shard_json_path = osp.join(self.caption_choice, tar_name.replace(".tar", ".json")) + shard_json = lru_json_load(shard_json_path) + caption = shard_json[url]["summary"]["output"] + except (KeyError, FileNotFoundError, json.decoder.JSONDecodeError): + if uuid not in self.missing_uids: + print("override caption not found for ", uuid) + self.missing_uids.add(uuid) + + # print(f"[DEBUG {uuid}]", caption) + + frames_loaded_successfully = len(images) + if caption is None: + caption = "" + prompt = "\n" * frames_loaded_successfully + caption + image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images]) + + input_ids = tokenizer_image_token( + prompt, + self.tokenizer, + return_tensors="pt", + ) + targets = copy.deepcopy(input_ids) + data_dict = dict(input_ids=input_ids, labels=targets, image=image_tensor) + + return data_dict + + +class DataCollatorForSupervisedDatasetSeqParallel: + """Collate examples for supervised fine-tuning. + This class is originally implemented by the LLaVA team and + modified by Haotian Tang.""" + + def __init__( + self, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + sp_degree: int, + sp_rank: int, + ring_degree: int, + ring_type: str, + ): + self.tokenizer = tokenizer + self.data_args = data_args + self.training_args = training_args + self.sp_degree = sp_degree + self.sp_rank = sp_rank + self.ring_degree = ring_degree + self.ring_type = ring_type + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels, images = [], [], [] + image_token_id = self.tokenizer.media_token_ids["image"] + video_token_id = self.tokenizer.media_token_ids["video"] + + for instance in instances: + if not isinstance(instance["input_ids"], list): + input_ids.append(instance["input_ids"]) + else: + input_ids += instance["input_ids"] + if not isinstance(instance["labels"], list): + labels.append(instance["labels"]) + else: + labels += instance["labels"] + # Note (kentang-mit@: we do not directly push tensors to + # images, but list of tensors. 
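+ # Video instances are unrolled below: the single video token in input_ids is
+ # replaced with one image token plus a newline per frame, the corresponding label
+ # positions are set to IGNORE_INDEX, and the stacked frames are then treated as
+ # ordinary images by the packing logic.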
+ if "video" in instance: + instance["image"] = torch.cat(instance["video"]) + video_id_pos = torch.where(input_ids[-1] == video_token_id)[0][0] + replace_ids = torch.Tensor( + ([image_token_id] + self.tokenizer.encode("\n")) * instance["image"].shape[0], + device=input_ids[-1].device, + ) + input_ids[-1] = torch.cat( + [input_ids[-1][:video_id_pos], replace_ids, input_ids[-1][video_id_pos + 1 :]] + ).to(input_ids[-1].dtype) + labels[-1] = torch.cat( + [ + labels[-1][:video_id_pos], + torch.Tensor([IGNORE_INDEX] * instance["image"].shape[0] * 2), + labels[-1][video_id_pos + 1 :], + ] + ).to(labels[-1].dtype) + instance.pop("video") + + if "image" in instance: + cur_image = instance["image"] + assert len(cur_image.shape) == 4 + # n_images, 3, size, size + if cur_image.shape[0] == 0: + warnings.warn("loaded one sample without images.") + if not isinstance(instance["input_ids"], list): + # datasets other than coyo, not packing >1 samples together + images.append(cur_image) + else: + # coyo-like datasets + images.extend(cur_image.chunk(cur_image.size(0), 0)) + else: + warnings.warn("loaded one sample without images.") + images.append([]) + # kentang-mit@: we need to make sure these two lists have + # the same length. We will use input_ids to filter out images corresponding + # to truncated tokens later. + + max_num_images = max([len(_images) for _images in images]) + for _images, _input_ids in zip(images, input_ids): + assert ( + len(_images) == (_input_ids == image_token_id).sum().item() + ), f"Number mismatch between images and placeholder image tokens in 'len(_images) == (_input_ids == image_token_id).sum().item()'.\ + Expect to have {len(_images)} images but only found {(_input_ids == image_token_id).sum().item()} images in tokens. \ + Error input_ids: {_input_ids} {self.tokenizer.decode([x if x != -200 else 200 for x in _input_ids])}" + + NUM_TOKENS_PER_IMAGE = self.data_args.num_image_tokens + if hasattr(self.data_args.image_processor, "crop_size"): + crop_size = self.data_args.image_processor.crop_size + else: + crop_size = self.data_args.image_processor.size + + # Init the padding sample + seq_id = 0 + while seq_id < len(input_ids): + # Skip the samples without images + dummy_image = torch.ones((1, 3, crop_size["height"], crop_size["width"]), device=input_ids[seq_id].device) + # dummy input_ids include one bos, one image token, and one eos + dummy_input_ids = torch.zeros_like(input_ids[seq_id][:3]) + dummy_input_ids[0] = self.tokenizer.bos_token_id + dummy_input_ids[1] = image_token_id + dummy_input_ids[2] = self.tokenizer.eos_token_id + dummy_labels = copy.deepcopy(dummy_input_ids) + dummy_labels[:2] = IGNORE_INDEX + dummy_seqlen = NUM_TOKENS_PER_IMAGE + 2 # TODO: Check the hard coding of 2 + dummy_position_ids = torch.arange(start=0, end=dummy_seqlen, dtype=torch.int32) + break + + # Sort with the real length of the sequence + combined = sorted( + zip(input_ids, labels, images), + key=lambda x: len(x[2]) * (NUM_TOKENS_PER_IMAGE - 1) + x[0].size(-1), + reverse=True, # Start Packing from the sequence with most images. 
+ ) + sorted_ids, sorted_labels, sorted_images = zip(*combined) + sorted_ids, sorted_labels, sorted_images = list(sorted_ids), list(sorted_labels), list(sorted_images) + max_seq_length = self.tokenizer.model_max_length # len(sorted_ids[0]) + max_sample_len = 0 + + batches = [] + label_batches = [] + position_ids = [] + batch_images = [] + seqlens_in_batch = [] + + i = 0 + while i < len(sorted_ids): + current_batch = torch.tensor([], dtype=torch.int32) + current_label_batch = torch.tensor([], dtype=torch.int32) + current_position_ids = torch.tensor([], dtype=torch.int32) + current_batch_images = [] + current_num_images = 0 + current_len = 0 + current_num_samples = 0 + + # Pack a few samples into one sample + while i < len(sorted_ids): + num_images = (sorted_ids[i] == image_token_id).sum().item() + num_image_tokens_added = num_images * (NUM_TOKENS_PER_IMAGE - 1) + num_incoming_tokens = sorted_ids[i].size(-1) + num_image_tokens_added + + # Handle RingAttn_Varlen which requires `seqlens_in_batch` should be divisible by `ring_degree` + if self.ring_degree > 1: + RING_PAD_TOKEN_INDEX = 2 + if self.ring_type == "ring_varlen": + if num_incoming_tokens % self.sp_degree != 0: + pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree + num_incoming_tokens += pad_len + # pad `input_ids` + pad_tensor = torch.full( + (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device + ) + sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor]) + + # pad `label` + pad_label_tensor = torch.full( + (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device + ) + sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor]) + elif self.ring_type == "zigzag_ring_varlen": + self.zigzag_sp_degree = self.sp_degree * 2 + if num_incoming_tokens % self.zigzag_sp_degree != 0: + pad_len = self.zigzag_sp_degree - num_incoming_tokens % self.zigzag_sp_degree + num_incoming_tokens += pad_len + # pad `input_ids` + pad_tensor = torch.full( + (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device + ) + sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor]) + + # pad `label` + pad_label_tensor = torch.full( + (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device + ) + sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor]) + else: + raise ValueError(f"Invalid ring_type: {self.ring_type}") + + if num_incoming_tokens > max_seq_length: + print( + f"Warning: Skipping one packed sample with {num_incoming_tokens} tokens,\ + please consider increase max seq len {max_seq_length}." 
+ ) + i += 1 + continue + + if ( + (current_num_images == 0) + or (current_num_images < self.sp_degree) + or ( + (current_num_images + num_images <= max_num_images) + and (current_len + num_incoming_tokens <= max_sample_len) + ) + ) and (current_len + num_incoming_tokens <= max_seq_length): + current_num_images += num_images + current_len += num_incoming_tokens + current_num_samples += 1 + current_position_ids = torch.cat( + (current_position_ids, torch.arange(start=0, end=num_incoming_tokens)), dim=0 + ) + current_batch = torch.cat((current_batch, sorted_ids[i]), dim=0) + sorted_labels[i][0] = IGNORE_INDEX + current_label_batch = torch.cat((current_label_batch, sorted_labels[i]), dim=0) + seqlens_in_batch.append(num_incoming_tokens) + current_batch_images.extend(sorted_images[i]) + i += 1 + assert current_num_images == len(current_batch_images) + else: + break + + # Padding the sample with the dummy image sample, if there are no enough images + MAX_RETRY = self.sp_degree + num_retry = 0 + while current_num_images < self.sp_degree and current_len < max_seq_length and num_retry <= MAX_RETRY: + current_num_images += dummy_image.size(0) + current_len += dummy_seqlen + current_num_samples += 1 + current_position_ids = torch.cat((current_position_ids, dummy_position_ids), dim=0) + current_batch = torch.cat((current_batch, dummy_input_ids), dim=0) + current_label_batch = torch.cat((current_label_batch, dummy_labels), dim=0) + seqlens_in_batch.append(dummy_seqlen) + current_batch_images.extend(dummy_image) + # We pad from left side to ensure correct grad flow + # current_batch = torch.cat((dummy_input_ids, current_batch), dim=0) + # current_label_batch = torch.cat((dummy_labels, current_label_batch), dim=0) + # seqlens_in_batch.insert(0, dummy_seqlen) + # current_batch_images = torch.cat((dummy_image, current_batch_images), dim=0) + num_retry += 1 + + # Drop the samples that do not have enough images + if current_num_images < self.sp_degree: + print(f"Warning: Skipping one packed sample with {current_num_images} images") + seqlens_in_batch = seqlens_in_batch[:-current_num_samples] + continue + + max_sample_len = max(max_sample_len, current_len) + batches.append(current_batch) + label_batches.append(current_label_batch) + position_ids.append(current_position_ids) + batch_images.append(current_batch_images) + + try: + assert current_num_images == len(torch.where(current_batch == image_token_id)[0].tolist()) + except AssertionError: + print(f"Error num_images on {self.sp_rank}", current_num_images) + print("current_batch", current_batch) + print( + f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:", + len(torch.where(current_batch == image_token_id)[0].tolist()), + ) + print(f"Error len(current_batch_images) on {self.sp_rank}:", len(current_batch_images)) + raise AssertionError + + # Split for sequence parallelism + for i in range(len(batches)): + image_token_indices = torch.where(batches[i] == image_token_id)[0].tolist() + image_ids = torch.arange(0, len(image_token_indices), dtype=torch.int32) + batches[i] = extract_local_input_ids( + batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id + ) + label_batches[i] = extract_local_input_ids( + label_batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id + ) + batch_images[i] = torch.concat( + extract_local_from_list(batch_images[i], self.sp_rank, self.sp_degree), dim=0 + ) + H, W = batch_images[i].size(-2), batch_images[i].size(-1) + 
batch_images[i] = batch_images[i].reshape(-1, 3, W, H) + num_images = len(batch_images[i]) + + try: + assert num_images == len(torch.where(batches[i] == image_token_id)[0].tolist()) + except AssertionError: + print(f"Error num_images on {self.sp_rank}", num_images) + print("batches[i]", batches[i]) + print( + f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:", + len(torch.where(batches[i] == image_token_id)[0].tolist()), + ) + print(f"Error batch_images[i] on {self.sp_rank}:", batch_images[i].shape) + raise AssertionError + position_ids[i] = extract_local_position_ids( + position_ids[i], image_token_indices, image_ids, self.sp_rank, self.sp_degree, NUM_TOKENS_PER_IMAGE - 1 + ) + + input_ids = torch.nn.utils.rnn.pad_sequence( + batches, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence(label_batches, batch_first=True, padding_value=IGNORE_INDEX) + seqlens_in_batch = [torch.tensor(x) for x in seqlens_in_batch] + seqlens_in_batch = torch.stack(seqlens_in_batch, axis=0) + seqlens_in_batch = seqlens_in_batch.flatten() + position_ids = torch.nn.utils.rnn.pad_sequence(position_ids, batch_first=True, padding_value=-1) + + if batch_images: + batch_images = [torch.unbind(images) for images in batch_images] + flat_batch_images = [item for sublist in batch_images for item in sublist] + else: + flat_batch_images = None + batch = dict( + input_ids=input_ids, + labels=labels, + # notice that we inject attention mask here + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + seqlens_in_batch=seqlens_in_batch, + media={"image": flat_batch_images}, + media_config={"image": {}}, + position_ids=position_ids, + ) + return batch + + +def make_supervised_data_module( + tokenizer: PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, +) -> Dict: + """Make dataset and collator for supervised fine-tuning. + This function is originally implemented by the LLaVA team and + modified by Jason Lu, Haotian Tang and Ligeng Zhu.""" + datasets_mixture.register_datasets_mixtures() + + from .builder import build_dataset + + train_dataset = build_dataset(data_args.data_mixture, data_args, training_args, tokenizer) + training_args.sample_lens = [len(d) for d in train_dataset.datasets] + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is None: + data_collator = DataCollator(tokenizer=tokenizer) + else: + sp_degree = training_args.seq_parallel_size + sp_rank = PROCESS_GROUP_MANAGER.sp_rank + ring_degree = PROCESS_GROUP_MANAGER.ring_degree + ring_type = PROCESS_GROUP_MANAGER.ring_type + data_collator = DataCollatorForSupervisedDatasetSeqParallel( + tokenizer=tokenizer, + data_args=data_args, + training_args=training_args, + sp_degree=sp_degree, + sp_rank=sp_rank, + ring_degree=ring_degree, + ring_type=ring_type, + ) + + return dict( + train_dataset=train_dataset, + data_collator=data_collator, + ) diff --git a/llava/data/datasets_mixture.py b/llava/data/datasets_mixture.py new file mode 100644 index 0000000000000000000000000000000000000000..62a53fcf5fe69256e5cd4b005ab7f28f63d9fe7c --- /dev/null +++ b/llava/data/datasets_mixture.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from dataclasses import dataclass, field + + +@dataclass +class Dataset: + dataset_name: str + dataset_type: str = field(default="torch") + data_path: str = field(default=None, metadata={"help": "Path to the training data."}) + meta_path: str = field(default=None, metadata={"help": "Path to the meta data for webdataset."}) + image_path: str = field(default=None, metadata={"help": "Path to the training image data."}) + speech_path: str = field(default=None, metadata={"help": "Path to the training speech data."}) + caption_choice: str = field(default=None, metadata={"help": "Path to the caption directory for recaption."}) + description: str = field( + default=None, + metadata={ + "help": "Detailed desciption of where the data is from, how it is labelled, intended use case and the size of the dataset." + }, + ) + test_script: str = (None,) + maintainer: str = (None,) + ############## ############## ############## ############## ############## ############## + caption_choice: str = field(default=None, metadata={"help": "Path to the captions for webdataset."}) + caption_choice_2: str = field(default=None, metadata={"help": "Path to the captions for webdataset."}) + start_idx: float = field(default=-1, metadata={"help": "Start index of the dataset."}) + end_idx: float = field(default=-1, metadata={"help": "Start index of the dataset."}) + + +DATASETS_LEGACY = {} + + +def add_dataset(dataset): + if dataset.dataset_name in DATASETS_LEGACY: + # make sure the data_name is unique + warnings.warn(f"{dataset.dataset_name} already existed in DATASETS. Make sure the name is unique.") + assert "+" not in dataset.dataset_name, "Dataset name cannot include symbol '+'." 
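+ # Register the dataset under its (unique) name; '+' is rejected presumably because
+ # mixture strings use '+' to concatenate dataset names (assumption; the parsing
+ # lives in the data builder, not in this file).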
+ DATASETS_LEGACY.update({dataset.dataset_name: dataset}) + + +def register_datasets_mixtures(): + ############## ############## ############## ############## ############## ############## + # Audio Datasets + ############## ############## ############## ############## ############## ############## + + data_mixture_1 = Dataset( + dataset_name="data_mixture_1", + dataset_type="torch", + data_path="/path/to/your/data_mixture_1/train.json", + ) + add_dataset(data_mixture_1) + + data_mixture_2 = Dataset( + dataset_name="data_mixture_2", + dataset_type="torch", + data_path="/path/to/your/data_mixture_2/train.json", + ) + add_dataset(data_mixture_2) + # Add more data mixtures below \ No newline at end of file diff --git a/llava/data/registry/datasets/audio_test.yaml b/llava/data/registry/datasets/audio_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..144c86b4bee8d3bbe67b32b7e7948e6f3c8e4bb9 --- /dev/null +++ b/llava/data/registry/datasets/audio_test.yaml @@ -0,0 +1,97 @@ +--- +Clotho-AQA-AQA: + _target_: llava.data.LLaVADataset + data_path: Clotho-AQA-AQA/test.json +Music-AVQA-AQA_All: + _target_: llava.data.LLaVADataset + data_path: Music-AVQA-AQA_All/test.json +CochlScene-SceneClassification: + _target_: llava.data.LLaVADataset + data_path: CochlScene-SceneClassification/test.json +NSynth-Source: + _target_: llava.data.LLaVADataset + data_path: NSynth-Source/test.json +NSynth-Instrument: + _target_: llava.data.LLaVADataset + data_path: NSynth-Instrument/test.json +FSD50k-EventClassification: + _target_: llava.data.LLaVADataset + data_path: FSD50k-EventClassification/test.json +Clotho-v2-AudioCaptioning: + _target_: llava.data.LLaVADataset + data_path: Clotho-v2-AudioCaptioning/test.json +audiocaps-AudioCaptioning: + _target_: llava.data.LLaVADataset + data_path: audiocaps-AudioCaptioning/test.json +ravdess-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: ravdess-EmotionClassification/val.json +GTZAN-GenreClassification: + _target_: llava.data.LLaVADataset + data_path: GTZAN-GenreClassification/test.json +UrbanSound8K-EventClassification: + _target_: llava.data.LLaVADataset + data_path: UrbanSound8K-EventClassification/train.json +Medley-solos-DB-InstrClassification: + _target_: llava.data.LLaVADataset + data_path: Medley-solos-DB-InstrClassification/test.json +ESC50-EventClassification: + _target_: llava.data.LLaVADataset + data_path: ESC50-EventClassification/train.json +CREMA-D-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: CREMA-D-EmotionClassification/test.json +IEMOCAP-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: IEMOCAP-EmotionClassification/test.json +MELD-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: MELD-EmotionClassification/test.json +MELD-SentimentClassification: + _target_: llava.data.LLaVADataset + data_path: MELD-SentimentClassification/test.json +MMAU: + _target_: llava.data.LLaVADataset + data_path: MMAU/test.json +MMAU-mini: + _target_: llava.data.LLaVADataset + data_path: MMAU/test-mini.json +AudioEntailmentQA: + _target_: llava.data.LLaVADataset + data_path: AudioEntailmentQA/test.json +SPGI-ASR: + _target_: llava.data.LLaVADataset + data_path: SPGI-ASR/val.json +SWBD-ASR: + _target_: llava.data.LLaVADataset + data_path: SWBD-ASR/val.json +LibriSpeech-ASR-clean: + _target_: llava.data.LLaVADataset + data_path: LibriSpeech-ASR/test_clean.json +LibriSpeech-ASR-other: + _target_: llava.data.LLaVADataset + data_path: LibriSpeech-ASR/test_other.json 
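+# The remaining entries follow the same pattern: a benchmark name mapped to a
+# llava.data.LLaVADataset config whose data_path is that split's JSON annotation
+# file (these names appear to be what eval_audio_bench.py receives as --task).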
+VoxPopuli-ASR: + _target_: llava.data.LLaVADataset + data_path: VoxPopuli-ASR/test.json +Europarl-ASR: + _target_: llava.data.LLaVADataset + data_path: Europarl-ASR/test.json +CV-ASR: + _target_: llava.data.LLaVADataset + data_path: CV-ASR/test.json +GigaSpeech-ASR: + _target_: llava.data.LLaVADataset + data_path: GigaSpeech-ASR/test.json +CompA-R-AQA: + _target_: llava.data.LLaVADataset + data_path: CompA-R-AQA/test.json +MuschoMusicQA: + _target_: llava.data.LLaVADataset + data_path: MuschoMusicQA/test.json +CMM: + _target_: llava.data.LLaVADataset + data_path: CMM/test.json +AIR-Bench: + _target_: llava.data.LLaVADataset + data_path: AIR-Bench/test.json diff --git a/llava/data/registry/datasets/default.yaml b/llava/data/registry/datasets/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1dc18dfebe7b092e029ba6661c41925d503477ae --- /dev/null +++ b/llava/data/registry/datasets/default.yaml @@ -0,0 +1,5 @@ +--- +dummy: + _target_: llava.data.DummyDataset + num_instances: 10000 + comments: dummy dataset for testing diff --git a/llava/data/registry/mixtures.yaml b/llava/data/registry/mixtures.yaml new file mode 100644 index 0000000000000000000000000000000000000000..70f69accb593595bd978de144a8b571b6c24b3b6 --- /dev/null +++ b/llava/data/registry/mixtures.yaml @@ -0,0 +1,78 @@ +--- +audio_speech_all: + -CV-ASR_1 + -MELD-EmotionClassification+ + -BBCSoundEffects-AudioDescription + -SWBD-ASR_1 + -WavCaps-SoundBible-AudioCaptioning + -AudioSet-Speech-Audio-QA + -SONYC-UST-EventClassification + -VoxPopuli-ASR_1 + -FSD50k-EventClassification + -SalmonnQA + -emov-db-EmotionClassification + -LLARK_MagnaTagATune-mir+tess-EmotionClassification + -Europarl-ASR_1 + -jl-corpus-EmotionClassification + -Ego-10-AudioCaptioning + -SPGI-ASR_1 + -CREMA-D-EmotionClassification + -MusicBenchQA + -WavCaps-BBC_Sound_Effects-AudioCaptioning + -NSynth-Instrument + -SpokenSquadQA + -NSynth-MIR + -AudioEntailmentQA + -GigaSpeech-ASR_1 + -WavCaps-AudioSet_SL-AudioCaptioning + -NonSpeech7k-EventClassification + -chime-home-EventClassification + -MusicCaps-AudioCaptioning + -LP-MusicCaps-MSD-AudioCaptioning + -Ego-30-AudioCaptioning + -NSynth-Source+Clotho-v2-AudioCaptioning + -LP-MusicCaps-MC-AudioCaptioning + -Clotho-AQA-EventClassification + -WavCaps-FreeSound-AudioCaptioning + -LLARK_MagnaTagATune-reasoning + -AudioSet-Temporal-Speech-Audio-QA + -TUT-EventClassification + -ESC50-EventClassification + -WavText5K-Tagging + -MELD-SentimentClassification + -Music-AVQA-AQA_All + -Music-AVQA-AVQA_All + -MACS-AudioCaptioning + -Medley-solos-DB-InstrClassification + -AudioSet-EventClassification + -OMGEmotion-EmotionClassification + -FMA-GenreClassification + -Epidemic_sound-AudioCaptioning + -CochlScene-SceneClassification + -LLARK_FMA-reasoning + -ravdess-EmotionClassification + -CompA-R-AQA + -MU-LLAMA-AQA + -musdbhq-InstrClassification + -UrbanSound8K-EventClassification + -audiocaps-AudioCaptioning + -VocalSound-VocalClassification + -CLAP_freesound-AudioCaptioning + -MMAUQA + -SongDescriber-AudioCaptioning + -HeySQuADQA + -Mira-AudioCaptioning + -Clotho-AQA-AQA + -LibriSpeech-ASR_1 + -IEMOCAP-EmotionClassification + -AudioSetFullwoAudioMusicCaps-EventClassification + -MSP-PODCAST-Publish-1.9-EmotionClassification + -OpenAQA-AQA + -SoundDescs-AudioDescription + -LibriSQA + -LLARK_FMA-mir + -LP-MusicCaps-MTT-AudioCaptioning + -GTZAN-GenreClassification + -musdbhq-captioning + -YesNoQA + diff --git a/llava/entry.py b/llava/entry.py new file mode 100644 index 
0000000000000000000000000000000000000000..44348e95eea2598a030f6953ff2098e133bed68f --- /dev/null +++ b/llava/entry.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import typing +from typing import List, Optional + +if typing.TYPE_CHECKING: + from transformers import PreTrainedModel +else: + PreTrainedModel = None + +__all__ = ["load"] + + +def load( + model_path: str, + model_base: Optional[str] = None, + devices: Optional[List[int]] = None, + **kwargs, +) -> PreTrainedModel: + import torch + + from llava.conversation import auto_set_conversation_mode + from llava.mm_utils import get_model_name_from_path + from llava.model.builder import load_pretrained_model + + auto_set_conversation_mode(model_path) + + model_name = get_model_name_from_path(model_path) + model_path = os.path.expanduser(model_path) + if os.path.exists(os.path.join(model_path, "model")): + model_path = os.path.join(model_path, "model") + + # Set `max_memory` to constrain which GPUs to use + if devices is not None: + assert "max_memory" not in kwargs, "`max_memory` should not be set when `devices` is set" + kwargs.update(max_memory={device: torch.cuda.get_device_properties(device).total_memory for device in devices}) + + model = load_pretrained_model(model_path, model_name, model_base, **kwargs)[1] + return model diff --git a/llava/eval/__init__.py b/llava/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d807fb5a2b32ea471a6bfec0d57694f57f91e377 --- /dev/null +++ b/llava/eval/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import os + +from llava.utils import io + +__all__ = ["EVAL_ROOT", "TASKS"] + + +EVAL_ROOT = "scripts/eval" +TASKS = io.load(os.path.join(os.path.dirname(__file__), "registry_audio.yaml")) diff --git a/llava/eval/eval_audio_bench.py b/llava/eval/eval_audio_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..b3017aef3ba10519f507e4d878d834c7ab3d7dbf --- /dev/null +++ b/llava/eval/eval_audio_bench.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
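+ # This script evaluates the model on one registered audio benchmark: each rank
+ # takes a strided slice of the task's JSON annotations (DATASETS[task]["data_path"]),
+ # generates a response per sound file, and appends predictions to
+ # outputs_<task>.jsonl, skipping ids already processed in earlier runs; partial
+ # results are gathered and saved every 20 new samples and once more at the end.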
+ +import argparse +import csv +import itertools +import json +import os + +import torch +from datasets import load_dataset +from tqdm import tqdm + +import llava +from llava import conversation as conversation_lib +from llava.data.builder import DATASETS +from llava.eval.mmmu_utils.eval_utils import parse_choice +from llava.utils import distributed as dist +from llava.utils import io +from llava.utils.logging import logger + + +def load_existing_ids(output_file): + if not os.path.exists(output_file): + return set(), [] + try: + with open(output_file, "r") as f: + lines = f.readlines() + outputs = [json.loads(line) for line in lines] + processed_ids = {item["id"] for item in outputs} + return processed_ids, outputs + except Exception as e: + print(f"Error loading existing outputs: {e}") + return set(), [] + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default=None) + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--conv-mode", type=str, default="auto") + parser.add_argument("--generation-config", type=json.loads) + parser.add_argument("--output-dir", type=str, default=None) + args = parser.parse_args() + + # Set up distributed environment + dist.init() + devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size()) + torch.cuda.set_device(devices[0]) + + # Load stage 3 model with line 56 + model = llava.load(args.model_base, model_base=None, devices=devices) + # Uncomment line 58-63 to load stage 3.5 model on top of stage 3 for thinking mode and long audio mode + # model = PeftModel.from_pretrained( + # model, + # args.model_path, + # device_map="auto", + # torch_dtype=torch.float16, + # ) + # Set up generation config + generation_config = model.default_generation_config + if args.generation_config is not None: + generation_config.update(**args.generation_config) + + # Load data and chunk it + json_file = DATASETS[args.task]["data_path"] + instances = io.load(json_file) + instances = instances[dist.rank() :: dist.size()] + + output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl") + processed_ids, outputs = load_existing_ids(output_path) + + count = len(outputs) + # Run inference + new_outputs = [] + for instance in tqdm(instances, disable=not dist.is_main()): + uuid = instance["id"] + sound_path = instance["sound"] + + if sound_path in processed_ids: + continue # Skip if already processed + sound = llava.Sound(sound_path) + conversations = instance["conversations"] + question = conversations[0]["value"] + + response = model.generate_content([sound, question], generation_config=generation_config) + + print("response", response) + + output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response} + new_outputs.append(output) + count = count +1 + if count % 20 == 0: + # Gather and save outputs + if dist.size() > 1: + outputs_new = dist.gather(new_outputs, dst=0) + if dist.is_main(): + outputs_new = list(itertools.chain(*outputs_new)) + final_outputs = outputs + outputs_new + io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs) + else: + final_outputs = outputs + new_outputs + io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs) + if dist.size() > 1: + new_outputs = dist.gather(new_outputs, dst=0) + if not dist.is_main(): + return + new_outputs = list(itertools.chain(*new_outputs)) + final_outputs = 
outputs + new_outputs + io.save(os.path.join(args.output_dir, "outputs_"+str(args.task)+".jsonl"), final_outputs) + +if __name__ == "__main__": + main() diff --git a/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62c020ee5b25a39cc6679b0f7f28235d1db3120c Binary files /dev/null and b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc differ diff --git a/llava/eval/mmmu_utils/eval_utils.py b/llava/eval/mmmu_utils/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..42c3cc6bc917ffa3f5af7ed7c2ec601ef8dbfbc1 --- /dev/null +++ b/llava/eval/mmmu_utils/eval_utils.py @@ -0,0 +1,61 @@ +# This file is originated from the official MMMU codebase: +# https://github.com/MMMU-Benchmark/MMMU + +import random + +import numpy as np + + +def parse_choice(response, all_choices, index2ans=None): + """ + Parse the prediction from the generated response. + Return the predicted index e.g., A, B, C, D. + """ + for char in [",", ".", "!", "?", ";", ":", "'"]: + response = response.strip(char) + response = " " + response + " " # add space to avoid partial match + + index_ans = True + ans_with_brack = False + candidates = [] + for choice in all_choices: # e.g., (A) (B) (C) (D) + if f"({choice})" in response: + candidates.append(choice) + ans_with_brack = True + + if len(candidates) == 0: + for choice in all_choices: # e.g., A B C D + if f" {choice} " in response: + candidates.append(choice) + + # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example + if len(candidates) == 0 and len(response.split()) > 5 and index2ans is not None: + for index, ans in index2ans.items(): + if ans.lower() in response.lower(): + candidates.append(index) + index_ans = False # it's content ans. + + if len(candidates) == 0: # still not get answer, randomly choose one. + pred_index = random.choice(all_choices) + elif len(candidates) > 1: + start_indexes = [] + if index_ans: + if ans_with_brack: + for can in candidates: + index = response.rfind(f"({can})") + start_indexes.append(index) # -1 will be ignored anyway + # start_indexes = [generated_response.index(f'({can})') for can in candidates] + else: + for can in candidates: + index = response.rfind(f" {can} ") + start_indexes.append(index) + else: + for can in candidates: + index = response.lower().rfind(index2ans[can].lower()) + start_indexes.append(index) + # get the last one + pred_index = candidates[np.argmax(start_indexes)] + else: # if only one candidate, use it. 
+ pred_index = candidates[0] + + return pred_index diff --git a/llava/eval/registry_audio.yaml b/llava/eval/registry_audio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8de55c73c3d12785044fb0cbd09c847eb548a514 --- /dev/null +++ b/llava/eval/registry_audio.yaml @@ -0,0 +1,93 @@ +Clotho-AQA-AQA: + tags: + - local +Music-AVQA-AQA_All: + tags: + - local +CochlScene-SceneClassification: + tags: + - local +NSynth-Source: + tags: + - local +NSynth-Instrument: + tags: + - local +FSD50k-EventClassification: + tags: + - local +Clotho-v2-AudioCaptioning: + tags: + - local +audiocaps-AudioCaptioning: + tags: + - local +ravdess-EmotionClassification: + tags: + - local +GTZAN-GenreClassification: + tags: + - local +UrbanSound8K-EventClassification: + tags: + - local +Medley-solos-DB-InstrClassification: + tags: + - local +ESC50-EventClassification: + tags: + - local +CREMA-D-EmotionClassification: + tags: + - local +IEMOCAP-EmotionClassification: + tags: + - local +MELD-EmotionClassification: + tags: + - local +MELD-SentimentClassification: + tags: + - local +MMAU: + tags: + - local +AudioEntailmentQA: + tags: + - local +SPGI-ASR: + tags: + - local +SWBD-ASR: + tags: + - local +LibriSpeech-ASR-clean: + tags: + - local +LibriSpeech-ASR-other: + tags: + - local +VoxPopuli-ASR: + tags: + - local +Europarl-ASR: + tags: + - local +CV-ASR: + tags: + - local +GigaSpeech-ASR: + tags: + - local +CompA-R-AQA: + tags: + - local +MuschoMusicQA: + tags: + - local +CMM: + tags: + - local +AIR-Bench: + tags: + - local diff --git a/llava/media.py b/llava/media.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcc0aa44034212a4649b9ef256eb5d995e48380 --- /dev/null +++ b/llava/media.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +__all__ = ["Media", "File", "Image", "Video", "Speech", "Sound"] + + +class Media: + pass + + +class File(Media): + def __init__(self, path: str) -> None: + self.path = path + + +class Image(File): + pass + + +class Video(File): + pass + + +class Speech(File): + pass + +class Sound(File): + pass \ No newline at end of file diff --git a/llava/mm_utils.py b/llava/mm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4a23d338f9a4440aeaa6f44f7f5fb0d89093fe --- /dev/null +++ b/llava/mm_utils.py @@ -0,0 +1,641 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL + +import base64 +import os +import tempfile +from io import BytesIO + +import numpy as np +import torch +from PIL import Image +from transformers import StoppingCriteria + +from pydub import AudioSegment +from torchvision import transforms +import soundfile as sf +from librosa import resample as librosa_resample +import whisper +import random +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler +DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds +def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None): + import cv2 + + if fps == None or frame_count == None: + # if one of fps or frame_count is None, still recompute + fps = vidcap.get(cv2.CAP_PROP_FPS) + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + if fps == 0 or frame_count == 0: + print(f"Video file not found. return empty images. {video_file_name}") + return [ + Image.new("RGB", (720, 720)), + ] * num_frames, 0, [0.] + + duration = frame_count / fps + frame_interval = frame_count // num_frames + if frame_interval == 0 and frame_count <= 1: + print(f"frame_interval is equal to 0. return empty image. {video_file_name}") + return [ + Image.new("RGB", (720, 720)), + ] * num_frames, 0, [0.] + # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval) + + images = [] + count = 0 + success = True + frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int) + frame_times = [frame / fps for frame in frame_indices] + while success: + # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval) + if frame_count >= num_frames: + success, frame = vidcap.read() + if count in frame_indices: + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except BaseException: + continue + if len(images) >= num_frames: + return images, num_frames, frame_times + count += 1 + else: + # Left padding frames if the video is not long enough + success, frame = vidcap.read() + if success: + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except BaseException: + continue + count += 1 + else: + break + if len(images) == 0: + raise ValueError("Did not find enough frames in the video. return empty image.") + + return images, len(images), frame_times + + +def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None): + """ + num_frames is the max number of frames the model can support. + frame_count is the number of frames in the input video. + max_fps is the max FPS of the model can support. + fps is the fps of the input video. 
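+    Returns a tuple of (frames, number of valid frames, frame timestamps in seconds).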
+ """ + + import random + + import cv2 + + if fps == None or frame_count == None: + # if one of fps or frame_count is None, still recompute + fps = vidcap.get(cv2.CAP_PROP_FPS) + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if fps == 0 or frame_count == 0: + print(f"Video file not found. return empty images. {video_file_name}") + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + + duration = frame_count / fps + # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps) + # If the video is too long (longer than max_fps and num_frames can support), + # we will use lower fps to sample frames. + if duration >= num_frames / max_fps: + frame_interval = frame_count // num_frames + + # If the video is too short, we will skip the video if there is only one frame. + if frame_interval == 0 and frame_count <= 1: + print(f"frame_interval is equal to 0. return empty image. {video_file_name}") + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + + images = [] + count = 0 + success = True + frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int) + frame_times = [frame / fps for frame in frame_indices] + while success: + if frame_count >= num_frames: + # success, frame = vidcap.read() + if count in frame_indices: + success, frame = vidcap.read() + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except: + # print("Failed to read frame:", count) + continue + if len(images) >= num_frames: + return images, num_frames, frame_times + else: + success = vidcap.grab() + count += 1 + else: + # Left padding frames if the video is not long enough + success, frame = vidcap.read() + if success: + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except: + # print("Failed to read frame:", count) + continue + count += 1 + else: + break + else: + frames_required = int(duration * max_fps) + frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int) + if frames_required == 0: + print(f"frames_required is fewer than 2. Duration {duration}, return empty image.") + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + elif frames_required == 1: + frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int) + images = [] + count = 0 + looked = 0 + success = True + + while success: + success, frame = vidcap.read() + if success and (looked in frame_indices): + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except: + continue + count += 1 + looked += 1 + frame_times = [frame / fps for frame in frame_indices] + + if len(images) == 0: + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + else: + return images, len(images), frame_times + + +def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None): + """ + Extract frames from a video using OpenCV. + + Args: + vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video. + frames (int): Number of frames to extract from the video. + fps (float): Frames per second of the video. 
If 0.0, the function will extract frames at equal intervals. + + Returns: + list: List of PIL Images extracted from the video. + + Raises: + NotImplementedError: If the type of `vpath_or_bytesio` is not supported. + """ + import cv2 + if isinstance(vpath_or_bytesio, str): + vidcap = cv2.VideoCapture(vpath_or_bytesio) + if max_fps > 0.0: + return get_frame_from_vcap_with_fps( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio + ) + return get_frame_from_vcap( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio + ) + elif isinstance(vpath_or_bytesio, (BytesIO,)): + # assuming mp4 + with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video: + temp_video.write(vpath_or_bytesio.read()) + temp_video_name = temp_video.name + vidcap = cv2.VideoCapture(temp_video_name) + if max_fps > 0.0: + return get_frame_from_vcap_with_fps( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name + ) + return get_frame_from_vcap( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name + ) + else: + raise NotImplementedError(type(vpath_or_bytesio)) + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + """ + Expand the given PIL image to a square shape by adding padding. + + Parameters: + - pil_img: The PIL image to be expanded. + - background_color: The color of the padding to be added. + + Returns: + - The expanded PIL image. + + If the image is already square, it is returned as is. + If the image is wider than it is tall, padding is added to the top and bottom. + If the image is taller than it is wide, padding is added to the left and right. 
+ """ + width, height = pil_img.size + if pil_img.mode == "L": + background_color = background_color[0] + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale + + processed_images = [] + + ########################################################################################## + ############# Add tiles for all but the last scale using fixed squre ratio ############### + ########################################################################################## + + for scale in s2_scales[:-1]: + target_width = image_size * (scale // s2_scales[0]) + target_height = image_size * (scale // s2_scales[0]) + blocks = (scale // s2_scales[0]) ** 2 + + # resize the image + resized_img = image.resize((target_width, target_height)) + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // 
(target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + ########################################################################################## + ################ Add tiles for the last scale using dynamic aspect ratio ################# + ########################################################################################## + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0]) + + + +def dynamic_s2_process_images_and_prompt(images, data_args, image_folder=None): + idx = 0 + all_images = [] + all_block_size = [] + for img in images: + processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True) + all_images.append(processed_images) + all_block_size.append(block_size) + idx += 2 + if all_images: + all_images = torch.cat(all_images) + else: + all_images = None + return all_images, all_block_size + + +def process_image( + image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None +): + processor = data_args.image_processor + if isinstance(image_file, str): + if image_folder is not None: + image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") + else: + image = Image.open(image_file).convert("RGB") + else: + # image is stored in bytearray + image = image_file + image = image.convert("RGB") + if hasattr(data_args.image_processor, "crop_size"): + # CLIP vision tower + crop_size = data_args.image_processor.crop_size + else: + # SIGLIP vision tower + assert hasattr(data_args.image_processor, "size") + crop_size = data_args.image_processor.size + if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2: + assert crop_size["height"] == crop_size["width"] + images, block_size = dynamic_s2_preprocess( + image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"] + ) + images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images] + return torch.stack(images), block_size + if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res: + assert crop_size["height"] == crop_size["width"] + if max_tiles is not None: + max_num = max_tiles + else: + max_num = data_args.max_tiles + images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"]) + images = 
[processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images] + return torch.stack(images) + + if data_args.image_aspect_ratio == "resize": + image = image.resize((crop_size["width"], crop_size["height"])) + if data_args.image_aspect_ratio == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + else: + # Using default behavior of the vision encoder + # For CLIP, default is central crop + # For Radio, default is central crop + # For Siglip, default is resize + # For InternVIT, default is resize + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + return image + +def get_num_windows(T, sr, max_num_window=5): + + window_length = int(30.0 * sr) + window_overlap = int(0.0 * sr) + max_num_window = max_num_window + + num_windows = 1 + if T <= window_length: + num_windows = 1 + full_length = window_length + elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap): + num_windows = max_num_window + full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap) + else: + num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap))) + full_length = num_windows * window_length - (num_windows - 1) * window_overlap + + return num_windows, full_length + +def load_audio(file_path, target_sr=16000, duration=30.0, start=0.0): + if file_path.endswith('.mp3'): + audio = AudioSegment.from_file(file_path) + if len(audio) > (start + duration) * 1000: + audio = audio[start * 1000:(start + duration) * 1000] + + if audio.frame_rate != target_sr: + audio = audio.set_frame_rate(target_sr) + + if audio.channels > 1: + audio = audio.set_channels(1) + + data = np.array(audio.get_array_of_samples()) + if audio.sample_width == 2: + data = data.astype(np.float32) / np.iinfo(np.int16).max + elif audio.sample_width == 4: + data = data.astype(np.float32) / np.iinfo(np.int32).max + else: + raise ValueError("Unsupported bit depth: {}".format(audio.sample_width)) + + else: + with sf.SoundFile(file_path) as audio: + original_sr = audio.samplerate + channels = audio.channels + + max_frames = int((start + duration) * original_sr) + + audio.seek(int(start * original_sr)) + frames_to_read = min(max_frames, len(audio)) + data = audio.read(frames_to_read) + + if data.max() > 1 or data.min() < -1: + data = data / max(abs(data.max()), abs(data.min())) + + if original_sr != target_sr: + if channels == 1: + data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr) + else: + data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0] + else: + if channels != 1: + data = data.T[0] + + if data.min() >= 0: + data = 2 * data / abs(data.max()) - 1.0 + else: + data = data / max(abs(data.max()), abs(data.min())) + + assert len(data.shape) == 1, data.shape + return data + +def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None): + model_cfg.image_processor = image_processor + new_images = [ + 
process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles) + for image in images + ] + + if all(x.shape == new_images[0].shape for x in new_images): + if len(new_images[0].shape) == 4: + new_images = torch.cat(new_images, dim=0) + elif len(new_images[0].shape) == 3: + new_images = torch.stack(new_images, dim=0) + else: + raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}") + else: + raise ValueError("The shape of images in new_images is different!") + return new_images + +def process_sounds(sounds): + sounds = torch.tensor(sounds) + return sounds + +def process_sound_masks(masks): + masks = torch.tensor(masks[0]) + return masks + +def tokenizer_image_token(prompt, tokenizer, return_tensors=None): + return tokenizer(prompt, return_tensors=return_tensors).input_ids[0] + +def is_gemma_tokenizer(tokenizer): + return "gemma" in tokenizer.__class__.__name__.lower() + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + outputs = [] + for i in range(output_ids.shape[0]): + outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) + return all(outputs) diff --git a/llava/model/FloatPointQuantizeTorch.py b/llava/model/FloatPointQuantizeTorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b01c6f8ad008cdcd01c51a85399a194032df05b1 --- /dev/null +++ b/llava/model/FloatPointQuantizeTorch.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
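+# Pure-PyTorch reference implementation of simulated low-precision floating-point
+# quantization. floatExMy_quantize_torch rounds a tensor to an "ExMy" format with
+# e_bit exponent bits and m_bit mantissa bits (e.g. e_bit=4, m_bit=3 for an
+# E4M3-style format), optionally with stochastic rounding; floatExM0 handles the
+# mantissa-free case, and the Dynamic variants implement dynamic-tree-style formats.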
+ +import math + +import torch + + +def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit + expo = torch.floor(torch.log2(x_abs)) + expo = torch.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / torch.exp2(expo) + + mant_int = torch.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + if stochastic: + noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5) + mant_frac.add_(noise) + mant_frac = torch.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * (2**expo) * mant_q + y = y.to(x) + + return y + + +def floatExM0_quantize_torch(x, e_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1) + expo = torch.log2(x_abs) + if stochastic: + noise = expo.new(expo.shape).uniform_(-0.5, 0.5) + expo.add(noise) + log_bias = math.log2(4 / 3) - 1 / 2 + expo.add(torch.ones_like(expo) * log_bias) + expo = torch.clamp(expo, min=Elow - 1, max=Ehigh) + expo = torch.round(expo) + + y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0 + y = y.to(x) + + return y + + +def Dynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + zero_mask = y.abs() > 1.01 * 10 ** (1 - bit) + y = y * zero_mask + y = y.to(x) + return y + + +def ZeroDynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + y = y.to(x) + return y diff --git a/llava/model/FloatPointQuantizeTriton.py b/llava/model/FloatPointQuantizeTriton.py new file mode 100644 index 0000000000000000000000000000000000000000..e619c3739ec968c92460b3d8d9efd2736fd07d44 --- /dev/null +++ b/llava/model/FloatPointQuantizeTriton.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
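+# Triton-backed counterpart of FloatPointQuantizeTorch: the same ExMy rounding is
+# performed inside _floatExMy_quantize_kernel and _floatExMy_stochastic_quantize_kernel.
+# Inputs are flattened and, when they exceed `segment_size` elements, split into
+# chunks before the kernel launch, since (per the comment below) Triton breaks when
+# x.numel() > 2 * 1024 ** 3.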
+ +import math +import struct + +import numpy as np +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +segment_size = 1024**3 + + +def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic): + x_ori_shape = x.shape + x = x.view(-1) + + n_elements = x.numel() + + if n_elements <= segment_size: + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + y = torch.empty_like(x) + + if x.dtype in [torch.bfloat16, torch.float32]: + if stochastic: + noise = x.new(x.shape).uniform_(-0.5, 0.5) + _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit) + else: + _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit) + torch.cuda.synchronize() + else: + raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton") + else: # Triton will break when x.numel > 2 * 1024 ** 3 + num_segments = n_elements // segment_size + 1 + split_size = [segment_size] * (num_segments - 1) + [n_elements - segment_size * (num_segments - 1)] + x_list = x.split(split_size) + y_list = [] + del x + + for x in x_list: + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + y = torch.empty_like(x) + + if x.dtype in [torch.bfloat16, torch.float32]: + if stochastic: + noise = x.new(x.shape).uniform_(-0.5, 0.5) + _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit) + else: + _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit) + torch.cuda.synchronize() + else: + raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton") + + y_list.append(y) + y = torch.concat(y_list) + del y_list + + y = y.reshape(x_ori_shape) + return y + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_warps=4, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_quantize_kernel( + x_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): + if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + # mant_frac = mant_frac + noise + mant_frac = libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_warps=4, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_stochastic_quantize_kernel( + x_ptr, + noise_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): 
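+    # Stochastic rounding path: the caller pre-generates uniform(-0.5, 0.5) noise on
+    # the host (see floatExMy_quantize_triton) and passes it via noise_ptr; the noise
+    # is added to the scaled mantissa fraction before rounding, which makes the
+    # rounding unbiased in expectation.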
+ if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + noise = tl.load(noise_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + mant_frac = mant_frac + noise + mant_frac = libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) diff --git a/llava/model/__init__.py b/llava/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5494e136efbc55e4813a40184308c26fe5949b --- /dev/null +++ b/llava/model/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .language_model.llava_llama import LlavaLlamaConfig, LlavaLlamaModel + +# FP8 related comments, development in progress (PI: ligeng zhu, haochen xi) +# NOTE: VLM + LLM +# from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel +# NOTE: Linear -> fp8, similar to transformer engine +# from .language_model.qllama import QLlamaConfig, QLlamaForCausalLM, QLlamaModel +# NOTE: Linear + Activation -> fp8, haochen's iclr version +# from .language_model.qmemllama import QMemLlamaConfig, QMemLlamaForCausalLM, QMemLlamaModel +""" +TODO: + linear(weights): + simulated fp8: done + real fp8: in-progress (code already implmented) + activation: + simulated fp8: done + real fp8: in-progress (still coding) + optimizers: + current VILA: bf16 + simulated fp8: done + real fp8 + fsdp (single node): done + real fp8 + fsdp (multiple node): in-progress +1. linear fp8 +2. activation fp8 +3. fp8 infernce example (load directly from a fp8 and fwd) +4. bind fp8 related configs to QLlamaConfig {"coat_fp8_args": {}} +""" +from .language_model.fp8linearqwen2 import FP8LinearQwen2Config, FP8LinearQwen2Model +from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..767d0612da9a1fd77058e3d004d16e45d572e3ca --- /dev/null +++ b/llava/model/apply_delta.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + +""" +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llava import LlavaLlamaForCausalLM + + +def apply_delta(base_model_path, target_model_path, delta_path): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading delta") + delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) + + print("Applying delta") + for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): + if name not in base.state_dict(): + assert name in [ + "model.mm_projector.weight", + "model.mm_projector.bias", + ], f"{name} not in base model" + continue + if param.data.shape == base.state_dict()[name].shape: + param.data += base.state_dict()[name] + else: + assert name in [ + "model.embed_tokens.weight", + "lm_head.weight", + ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" + bparam = base.state_dict()[name] + param.data[: bparam.shape[0], : bparam.shape[1]] += bparam + + print("Saving target model") + delta.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + + args = parser.parse_args() + + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/llava/model/builder.py b/llava/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..4c715ddf2eaeaa59d9807b6fdc3d31a71005067c --- /dev/null +++ b/llava/model/builder.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# This file is modified from https://github.com/haotian-liu/LLaVA/ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
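+# load_pretrained_model resolves a checkpoint (multimodal or plain language model),
+# optionally merging LoRA/DoRA adapter weights onto a base model and applying
+# 4-/8-bit quantization, and returns (tokenizer, model, image_processor, context_len).
+# Hypothetical example (paths and model name are placeholders):
+#   tokenizer, model, _, context_len = load_pretrained_model("/path/to/ckpt", "af3-model")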
+ + +import os +import warnings + +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PretrainedConfig + +from llava.model import LlavaLlamaModel +from llava.model.utils import is_mm_model + + +def load_pretrained_model( + model_path, + model_name, + model_base=None, + load_8bit=False, + load_4bit=False, + device_map="auto", + device="cuda", + **kwargs, +): + kwargs = {"device_map": device_map, **kwargs} + + if device != "cuda": + kwargs["device_map"] = {"": device} + + if load_8bit: + kwargs["load_in_8bit"] = True + elif load_4bit: + kwargs["load_in_4bit"] = True + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + else: + kwargs["torch_dtype"] = torch.float16 + # kwargs["torch_dtype"] = torch.bfloat16 + + if is_mm_model(model_path): + # Load LLaVA model + ## TODO @yunhao: mind fixing lora + if "lora" in model_name.lower() and model_base is None: + warnings.warn( + "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged." + ) + if ("lora" in model_name.lower() or "dora" in model_name.lower()) and model_base is not None: + lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) + print(lora_cfg_pretrained) + print("Loading LLaVA from base model...") + config = AutoConfig.from_pretrained(model_base) + prepare_config_for_eval(config, kwargs) + model = LlavaLlamaModel.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs) + tokenizer = model.tokenizer + token_num, tokem_dim = model.llm.lm_head.out_features, model.llm.lm_head.in_features + if model.llm.lm_head.weight.shape[0] != token_num: + model.llm.lm_head.weight = torch.nn.Parameter( + torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype) + ) + model.llm.embed_tokens.weight = torch.nn.Parameter( + torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype) + ) + + print("Loading additional LLaVA weights...") + if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): + non_lora_trainables = torch.load( + os.path.join(model_path, "non_lora_trainables.bin"), + map_location="cpu", + ) + else: + # this is probably from HF Hub + from huggingface_hub import hf_hub_download + + def load_from_hf(repo_id, filename, subfolder=None): + cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder) + return torch.load(cache_file, map_location="cpu") + + non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin") + non_lora_trainables = { + (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items() + } + if any(k.startswith("model.model.") for k in non_lora_trainables): + non_lora_trainables = { + (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items() + } + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + + print("Loading LoRA weights...") + model = PeftModel.from_pretrained(model, model_path) + print("Merging LoRA weights...") + model = model.merge_and_unload() + print("Model is loaded...") + else: + config = AutoConfig.from_pretrained(model_path) + config.resume_path = model_path + prepare_config_for_eval(config, kwargs) + model = LlavaLlamaModel(config=config, 
low_cpu_mem_usage=True, **kwargs) + tokenizer = model.tokenizer + else: + # Load language model + if model_base is not None: + # PEFT model + from peft import PeftModel + + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) + print(f"Loading LoRA weights from {model_path}") + model = PeftModel.from_pretrained(model, model_path) + print(f"Merging weights") + model = model.merge_and_unload() + print("Convert to FP16...") + model.to(torch.float16) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, legacy=False) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + model.eval() + image_processor = None + if is_mm_model(model_path): + model.resize_token_embeddings(len(tokenizer)) + + if hasattr(model.llm.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, image_processor, context_len + + +def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict): + try: + # compatible with deprecated config convention + if getattr(config, "vision_tower_cfg", None) is None: + config.vision_tower_cfg = config.mm_vision_tower + except AttributeError: + raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}") + + config.model_dtype = kwargs.pop("torch_dtype").__str__() diff --git a/llava/model/coat/activation/__init__.py b/llava/model/coat/activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69923298db60fa8275bf22895fd5c54cf102e9b9 --- /dev/null +++ b/llava/model/coat/activation/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py new file mode 100644 index 0000000000000000000000000000000000000000..04acd2fb73f707152839e1a97acb67e6f8df873b --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch + + +def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit + expo = torch.floor(torch.log2(x_abs)) + expo = torch.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / torch.exp2(expo) + + mant_int = torch.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + if stochastic: + noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5) + mant_frac.add_(noise) + mant_frac = torch.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * (2**expo) * mant_q + y = y.to(x) + + return y + + +def floatExM0_quantize_torch(x, e_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1) + expo = torch.log2(x_abs) + if stochastic: + noise = expo.new(expo.shape).uniform_(-0.5, 0.5) + expo.add(noise) + log_bias = math.log2(4 / 3) - 1 / 2 + expo.add(torch.ones_like(expo) * log_bias) + expo = torch.clamp(expo, min=Elow - 1, max=Ehigh) + expo = torch.round(expo) + + y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0 + y = y.to(x) + + return y + + +def Dynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + zero_mask = y.abs() > 1.01 * 10 ** (1 - bit) + y = y * zero_mask + y = y.to(x) + return y + + +def ZeroDynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + y = y.to(x) + return y diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py new file mode 100644 index 0000000000000000000000000000000000000000..db85e7fc42dc54539f54400ccff280402e126f0d --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import math +import struct + +import numpy as np +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + + +def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic): + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + y = torch.zeros_like(x) + + if x.dtype in [torch.bfloat16, torch.float32]: + if stochastic: + noise = x.new(x.shape).uniform_(-0.5, 0.5) + _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit) + else: + _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit) + else: + raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton") + + return y + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_stages=1, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_quantize_kernel( + x_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): + if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + # mant_frac = mant_frac + noise + mant_frac = libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_stages=1, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_stochastic_quantize_kernel( + x_ptr, + noise_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): + if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + noise = tl.load(noise_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + mant_frac = mant_frac + noise + mant_frac = 
libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) diff --git a/llava/model/coat/activation/fake_quantization/quantize_function.py b/llava/model/coat/activation/fake_quantization/quantize_function.py new file mode 100644 index 0000000000000000000000000000000000000000..6b18d45603ecd3f6564a8dbac3810d8ce701fbb6 --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/quantize_function.py @@ -0,0 +1,239 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import re + +import torch + +from .FloatPointQuantizeTorch import * +from .FloatPointQuantizeTriton import * + + +def block_cut(input, row_block, column_block, pad_block=False): + # print(input.shape) + original_shape = input.shape + # input tensor shape is M * N + if len(input.shape) > 2: + input = input.reshape(-1, input.shape[2]) + elif len(input.shape) == 2: + pass + else: + raise ValueError(f"input shape {input.shape} does not match for block cut, {input}") + M, N = input.shape[0], input.shape[1] + + if row_block == -1: + row_block = M + if column_block == -1: + column_block = N + + if pad_block: + row_remainder, col_remainder = M % row_block, N % column_block + if row_remainder: + row_pad = row_block - row_remainder + else: + row_pad = 0 + if col_remainder: + col_pad = column_block - col_remainder + else: + col_pad = 0 + + input = torch.nn.functional.pad( + input, (0, col_pad, 0, row_pad), "constant", 0 + ) # refer to torch's doc to see why + M, N = input.shape[0], input.shape[1] + row_num, column_num = M // row_block, N // column_block + else: + row_num, column_num = M // row_block, N // column_block + + assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}" + assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}" + # print(input.shape) + input = ( + input.reshape(row_num, row_block, column_num, column_block) + .permute(0, 2, 1, 3) + .reshape(row_num * column_num, row_block, column_block) + ) + # print(input.shape) + return input + + +def block_reshape(input, origin_input, row_block, column_block, pad_block=False): + if len(origin_input.shape) > 2: + flatten_input = origin_input.reshape(-1, origin_input.shape[2]) + elif len(origin_input.shape) == 2: + flatten_input = origin_input + else: + raise ValueError(f"input shape {input.shape} does not match for block cut") + + M, N = flatten_input.shape[0], flatten_input.shape[1] + + if row_block == -1: + row_block = M + if column_block == -1: + column_block = N + + if pad_block: + row_remainder, col_remainder = M % row_block, N % column_block + if row_remainder: + row_pad = row_block - 
row_remainder + else: + row_pad = 0 + if col_remainder: + col_pad = column_block - col_remainder + else: + col_pad = 0 + + pad_origin_input = torch.nn.functional.pad(origin_input, (0, col_pad, 0, row_pad), "constant", 0) + M, N = pad_origin_input.shape[0], pad_origin_input.shape[1] + row_num, column_num = M // row_block, N // column_block + else: + row_num, column_num = M // row_block, N // column_block + + input = ( + input.reshape(row_num, column_num, row_block, column_block) + .permute(0, 2, 1, 3) + .reshape(row_num * row_block, column_num * column_block) + ) + + M, N = flatten_input.shape[0], flatten_input.shape[1] + input = input[:M, :N] + + if len(origin_input.shape) > 2: + input = input.reshape(origin_input.shape) + elif len(origin_input.shape) == 2: + pass + else: + raise ValueError(f"input shape {input.shape} does not match for block reshape") + + return input + + +def block_verify_int8(input, row_block, column_block, layer_type, necessary=True): + Binput = block_cut(input, row_block, column_block) + Binput = Binput.to(torch.float32) + + for n in range(Binput.shape[0]): + unique_values = len(torch.unique(Binput[n, :, :])) + if unique_values > 256: + if necessary: + raise ValueError(f"{layer_type} contains more than 256 unique values.") + else: + return False + return True + + +def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name): + Quant_fn = SymmQuantizer + return Quant_fn.apply(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name) + + +def extract_bit(string): + match = re.match(r"INT(\d+)", string) # INT8 + if match: + return "integer", int(match.group(1)), None + match = re.match(r"E(\d+)M(\d+)", string) # E4M3 / E5M2 + if match: + Ebit, Mbit = int(match.group(1)), int(match.group(2)) + if Ebit == 1: + return "integer", Mbit + 1, None + if Mbit == 0: + return "floatExM0", int(match.group(1)), 0 + return "floatExMy", int(match.group(1)), int(match.group(2)) + match = re.match(r"DE(\d+)", string) + if match: + return "Dynamic", int(match.group(1)), None + match = re.match(r"ZeroD(\d+)", string) + if match: + return "ZeroDynamic", int(match.group(1)), None + raise ValueError(f"{string} data format is not supported") + + +class SymmQuantizer(torch.autograd.function.InplaceFunction): + @staticmethod + def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None): + with torch.no_grad(): + absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon + + if bits == "100" or not apply_quantize: + return input, input, torch.ones_like(absmax_per_block) + elif bits == "FP32": + return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block) + elif bits == "FP16": + return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block) + elif bits == "BF16": + return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block) + else: + QuantType, bit1, bit2 = extract_bit(bits) + if not symm: + bit1 = bit1 + 1 # pretend to be asymmtric + + if QuantType == "integer": + Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1 + elif QuantType == "floatExMy": + Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * ( + 2 ** (2 ** (bit1 - 1)) + ) + if bit1 == 4 and bit2 == 3: # E4M3 + Qn, Qp = -448, 448 + if bit1 == 5 and bit2 == 2: # E5M2 + Qn, Qp = -57344, 57344 + elif QuantType == "floatExM0": + Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1)) + elif QuantType == "Dynamic": + Qn, Qp = 
-1, 1 + elif QuantType == "ZeroDynamic": + Qn, Qp = -1, 1 + else: + raise NotImplementedError(f"{bits} is not supported by quantization") + scale_per_block = (2 * absmax_per_block) / (Qp - Qn) + scale_per_block = scale_per_block.to(input) + + Qinput = input / scale_per_block + + if QuantType == "integer": + if stochastic: + noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5) + Qinput.add_(noise) + Qinput.clamp_(Qn, Qp).round_() + elif QuantType == "floatExMy": + # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic) + Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic) + elif QuantType == "floatExM0": + Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic) + else: + raise NotImplementedError(f"{bits} is not supported by quantization") + + RQinput = Qinput * scale_per_block + + if input.dtype != Qinput.dtype: + print( + f"Input type is {input.dtype}, Qinput type is {Qinput.dtype}, scale_per_block type is {scale_per_block.dtype}", + file=open("debug.txt", "a"), + ) + import IPython + + IPython.embed() + return RQinput, Qinput, scale_per_block + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None, None, None, None, None diff --git a/llava/model/coat/activation/fake_quantization/utils.py b/llava/model/coat/activation/fake_quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4ec2dae1ef7be1dc7c82c79439b7091d49feb6 --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/utils.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
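# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how the block-wise
# helpers above (block_cut -> block_quant -> block_reshape) compose into a
# fake-quantization round trip. The import path and tensor shape are
# assumptions; the ExMy formats dispatch to the Triton kernel, so a CUDA
# device is assumed as well.
# ---------------------------------------------------------------------------
import torch

from llava.model.coat.activation.fake_quantization.quantize_function import (
    block_cut,
    block_quant,
    block_reshape,
)


def fake_quantize_e4m3_blockwise(x: torch.Tensor, row_block: int = 1, column_block: int = 16) -> torch.Tensor:
    # 1) Tile the (M, N) tensor into (num_blocks, row_block, column_block);
    #    M and N must be divisible by the block sizes unless pad_block=True.
    blocks = block_cut(x, row_block, column_block)
    # 2) Symmetric per-block quantization to E4M3: each block is scaled by
    #    2 * absmax / (Qp - Qn), with Qp = -Qn = 448 for E4M3, then rounded.
    rq, q, scale = block_quant(blocks, True, "E4M3", False, 1e-10, True, "demo_layer")
    # 3) Undo the tiling so the dequantized output matches the input layout.
    return block_reshape(rq, x, row_block, column_block)


# Example call (assumed shapes): x = torch.randn(4, 64, device="cuda")
# x_dq = fake_quantize_e4m3_blockwise(x)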
+# +# SPDX-License-Identifier: Apache-2.0 + +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch + + +def list_has_common_element(list1, list2): + set1 = set(list1) + set2 = set(list2) + return len(set1.intersection(set2)) > 0 + + +def calculate_scale_num(input, row_block, col_block): + if len(input.shape) > 2: + input = input.reshape(-1, input.shape[2]) + elif len(input.shape) == 2: + pass + else: + raise ValueError(f"input shape {input.shape} does not match for block cut, {input}") + M, N = input.shape[0], input.shape[1] + + if row_block == -1: + row_block = M + if col_block == -1: + col_block = N + + return input.numel() / (row_block * col_block) + + +def quant_get_local_rank() -> int: + return int(os.environ.get("LOCAL_RANK") or 0) + + +def format_string_with_condition( + input_string, + condition_config, + symm, + bits, + blocksize_config, + input_pad=20, +): + padded_string = input_string.ljust(input_pad) + output_string = padded_string + + for k, v in condition_config.items(): + if v: + output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6) + else: + output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6) + + output_string = output_string + f"Symm {symm}".ljust(10) + + for k, v in bits.items(): + output_string = output_string + f"{k} bit".ljust(10) + v.ljust(10) + for k, v in blocksize_config.items(): + output_string += f"{k}: {v}".ljust(15) + + return output_string + + +def print_warning(sentence): + print("*" * (len(sentence) + 4)) + print(f"* {sentence} *") + print("*" * (len(sentence) + 4)) + + +def check_nan_inf(tensor, check_nan, check_inf): + if check_nan: + contain_nan = torch.isnan(tensor).any() + else: + contain_nan = False + if check_inf: + contain_inf = torch.isinf(tensor).any() + else: + contain_inf = False + return contain_nan, contain_inf + + +def move_torch_to_numpy(tensor): + if tensor is None: + return None + + if tensor.is_cuda: + tensor = tensor.cpu() + return tensor.detach().float().numpy() + + +def flatten_to_1d(tensor): + if tensor is None: + return None + + return tensor.reshape(-1) diff --git a/llava/model/coat/activation/models/_fp8_quantization_config.py b/llava/model/coat/activation/models/_fp8_quantization_config.py new file mode 100644 index 0000000000000000000000000000000000000000..579df908203575abf8da9565995640c4ae459e3b --- /dev/null +++ b/llava/model/coat/activation/models/_fp8_quantization_config.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
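# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): what calculate_scale_num
# above reports, and the storage overhead it implies. The shape and import path
# are assumptions for illustration.
# ---------------------------------------------------------------------------
import torch

from llava.model.coat.activation.fake_quantization.utils import calculate_scale_num

w = torch.empty(4096, 4096)
n_scales = calculate_scale_num(w, row_block=1, col_block=16)
print(n_scales)                # 1048576.0 -> one scale per 1x16 block
print(n_scales / w.numel())    # 0.0625    -> ~6.25% extra entries for the scales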
+ +from dataclasses import dataclass + +from transformers import PretrainedConfig + + +@dataclass +class QuantizationConfig: + quantize_model: str = "false" + symm: bool = True + epsilon: float = 1e-10 + fabit: str = "E4M3" + fwbit: str = "E4M3" + fobit: str = "E4M3" + babit: str = "E5M2" + bwbit: str = "E5M2" + bobit: str = "E5M2" + qchoice: str = "none" + group_size: int = -1 + pad_to_multiple_of: int = 0 + weight_memory_efficient: bool = True + + # Legacy + row_blocksize: int = -1 + col_blocksize: int = -1 + + def __init__( + self, + quantize_model: str = "false", + symm: bool = True, + epsilon: float = 1e-10, + fabit: str = "E4M3", + fwbit: str = "E4M3", + fobit: str = "E4M3", + babit: str = "E5M2", + bwbit: str = "E5M2", + bobit: str = "E5M2", + qchoice: str = "none", + group_size: int = -1, + pad_to_multiple_of: int = 0, + weight_memory_efficient: bool = True, + row_blocksize: int = -1, + col_blocksize: int = -1, + **kwargs, + ): + super().__init__() + self.quantize_model = quantize_model + self.symm = symm + self.epsilon = epsilon + self.fabit = fabit + self.fwbit = fwbit + self.fobit = fobit + self.babit = babit + self.bwbit = bwbit + self.bobit = bobit + self.qchoice = qchoice + self.group_size = group_size + self.pad_to_multiple_of = pad_to_multiple_of + self.weight_memory_efficient = weight_memory_efficient + + self.row_blocksize = row_blocksize + self.col_blocksize = col_blocksize diff --git a/llava/model/coat/activation/models/_fp8_weightcache.py b/llava/model/coat/activation/models/_fp8_weightcache.py new file mode 100644 index 0000000000000000000000000000000000000000..85b1c504c74b83299181020e1275133428f93e1d --- /dev/null +++ b/llava/model/coat/activation/models/_fp8_weightcache.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
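# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): constructing the
# QuantizationConfig defined above. Later modules in this repo build it as
# QuantizationConfig(**config.coat_fp8_args); the concrete values here are
# assumptions for illustration (E4M3/E5M2 match the defaults).
# ---------------------------------------------------------------------------
from llava.model.coat.activation.models._fp8_quantization_config import QuantizationConfig

coat_fp8_args = {
    "quantize_model": "true",         # assumed flag value; the default is "false"
    "fabit": "E4M3",                  # forward activations
    "fwbit": "E4M3",                  # forward weights
    "babit": "E5M2",                  # backward activation gradients
    "bwbit": "E5M2",                  # backward weight gradients
    "group_size": 16,                 # assumed per-group quantization size
    "weight_memory_efficient": True,  # cache only per-weight scales, not FP8 copies
}
qargs = QuantizationConfig(**coat_fp8_args)
assert qargs.babit == "E5M2" and qargs.group_size == 16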
+ +import torch.nn as nn + +from ..real_quantization import fp8_division_transpose + + +class FP8CacheWeightModule(nn.Module): + def __init__(self, config, qargs, layer_id): + super().__init__() + self.config = config + self.qargs = qargs + self.layer_id = layer_id + + def prepare_weight(self, weight, weight_name, is_first_microbatch): + if is_first_microbatch: + if self.qargs.weight_memory_efficient: + # print(f"{weight_name} uses first microbatch") + weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose( + weight, self.qargs.group_size, self.fwobits["fwbit"] + ) + setattr(self, f"{weight_name}_fp8_scale", weight_s) + return weight_fp8, weight_fp8_t, weight_s + else: + # print(f"{weight_name} uses first microbatch") + weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose( + weight, self.qargs.group_size, self.fwobits["fwbit"] + ) + setattr(self, f"{weight_name}_fp8", weight_fp8) + setattr(self, f"{weight_name}_fp8_t", weight_fp8_t) + setattr(self, f"{weight_name}_fp8_scale", weight_s) + return weight_fp8, weight_fp8_t, weight_s + else: + if self.qargs.weight_memory_efficient: + return getattr(self, f"{weight_name}_fp8_scale") + else: + return ( + getattr(self, f"{weight_name}_fp8"), + getattr(self, f"{weight_name}_fp8_t"), + getattr(self, f"{weight_name}_fp8_scale"), + ) + + def forward(self, x): + pass diff --git a/llava/model/coat/activation/models/_fp8manager.py b/llava/model/coat/activation/models/_fp8manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5270acf79adc7d17838959b0c679cba3b43c5bee --- /dev/null +++ b/llava/model/coat/activation/models/_fp8manager.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +class FP8Manager: + """Class to keep track of and manipulate the global + FP8 state at different stages of execution. + """ + + is_first_microbatch = False diff --git a/llava/model/coat/activation/models/coat_llama.py b/llava/model/coat/activation/models/coat_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6763c027ac05d04c1b66b483cf886c0e1f7fab --- /dev/null +++ b/llava/model/coat/activation/models/coat_llama.py @@ -0,0 +1,1479 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
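# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): one plausible way the
# FP8Manager.is_first_microbatch flag above can be driven during gradient
# accumulation, so that prepare_weight() re-quantizes (and re-caches) each
# weight only once per optimizer step. The training-loop shape is an
# assumption; only the flag and the caching behaviour come from the files above.
# ---------------------------------------------------------------------------
from llava.model.coat.activation.models._fp8manager import FP8Manager


def training_step(model, optimizer, microbatches):
    optimizer.zero_grad()
    for i, batch in enumerate(microbatches):
        # Weights do not change within one optimizer step, so their FP8 copies
        # (or just their scales, in the memory-efficient mode) are reused after
        # the first microbatch.
        FP8Manager.is_first_microbatch = (i == 0)
        loss = model(**batch).loss / len(microbatches)
        loss.backward()
    optimizer.step()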
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os +from fnmatch import fnmatch +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaForCausalLM, + LlamaLinearScalingRotaryEmbedding, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + _prepare_4d_causal_attention_mask_with_cache_position, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) + +from ..real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) + +# FP8 related +from ._fp8_quantization_config import QuantizationConfig +from ._fp8_weightcache import FP8CacheWeightModule +from ._fp8manager import FP8Manager + +logger = logging.get_logger(__name__) + + +class CoatLlamaConfig(LlamaConfig): + model_type = "fp8_llama" + + +class CoatLlamaBeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_idx: Optional[int] = None): + 
super().__init__(config, qargs, layer_idx) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch) + return _CoatLlamaBeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + None, + None, + weight1_s, + self.k_proj.weight, + None, + None, + weight2_s, + self.v_proj.weight, + None, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + return _CoatLlamaBeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + weight1, + weight1_t, + weight1_s, + self.k_proj.weight, + weight2, + weight2_t, + weight2_s, + self.v_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _CoatLlamaBeforeAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") 
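        # Added note (not in the original file): `.view(torch.float8_e4m3fn)`
        # reinterprets the stored bits as FP8 without converting values, unlike
        # the commented-out `.to(...)`. The quantized activation travels between
        # layers inside a carrier dtype so that autograd sees matching dtypes in
        # forward and backward; the backward pass below applies the same
        # `.view(...)` trick to the gradient before returning it.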
+ + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size) # query states + fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size) # key states + fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size) # value states + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s + else: + ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return re_x, fc1_x, fc2_x, fc3_x + + @staticmethod + def backward(ctx, fp_grad, query_g, key_g, value_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight + + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout. + query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose( + query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose( + key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose( + value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + + fc1_g1, att_q_wg = fp8_linear_backward( + ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size + ) + fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size) + fc1_g3, att_v_wg = fp8_linear_backward( + ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size + ) + + fc1_g = fc1_g1 + fc1_g2 + fc1_g3 + + # LayerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together, and prepare the input of the next layer. 
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd. + return ( + re_g, + in_g, + in_sg_g16, + att_q_wg, + None, + None, + None, + att_k_wg, + None, + None, + None, + att_v_wg, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class CoatLlamaAfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + def forward(self, re_x, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch) + + return _CoatLlamaAfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + None, + None, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight4, weight4_t, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + + return _CoatLlamaAfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + weight4, + weight4_t, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _CoatLlamaAfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs + ): + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight4 is None # memory efficient + weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s) + fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) # + + # import IPython + # IPython.embed() + # Add the activations together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight4_t is None + ctx.weight = weight4_origin, weight4_s + else: + ctx.weight = weight4_t, weight4_s + ctx.group_size = 
group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight4_t, weight4_s = ctx.weight + group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight4_t, weight4_s = fp8_division_transpose( + weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True + ) + fc4_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size + ) + + return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class CoatLlamaMLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.training = True + + # below is only used when training = False + assert config.hidden_act == "silu", "We only support silu activation currently" + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch) + + return _CoatLlamaMLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + None, + None, + weight1_s, + self.up_proj.weight, + None, + None, + weight2_s, + self.down_proj.weight, + None, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = 
self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + return _CoatLlamaMLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + weight1, + weight1_t, + weight1_s, + self.up_proj.weight, + weight2, + weight2_t, + weight2_s, + self.down_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("Need TODO") + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _CoatLlamaMLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. + if qargs.weight_memory_efficient: + assert weight1 is None and weight2 is None and weight3 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj + up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight3 is None: # memory efficient + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size) + + # Add the activation together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + fc3_g, weight3_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight3_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward( + silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Silu activation + silu_g, silu_gs, silu_g_t = fp8_silu_backward( + gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + + # Gate Proj + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size + ) + fc2_g, weight2_grad = fp8_linear_backward( + ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size + ) + + fc_g = fc1_g + fc2_g + + # layerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + weight3_grad, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class LlamaAttentionWithoutLinear(nn.Module): + """ + Remove the Q/K/V/O projection layer in LlamaAttention module and only calculate the attention logic. + The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module. 
+ """ + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2WithoutLinear(LlamaAttentionWithoutLinear): + """ + Remove the Q/K/V/O projection layer in LlamaFlashAttention2 module and only calculate the attention logic. + The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = query_states.size() + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaSdpaAttentionWithoutLinear(LlamaAttentionWithoutLinear): + """ + Remove the Q/K/V/O projection layer in LlamaSdpaAttention module and only calculate the attention logic. + The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + return attn_output, None, past_key_value + + +COAT_LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttentionWithoutLinear, + "flash_attention_2": LlamaFlashAttention2WithoutLinear, + "sdpa": LlamaSdpaAttentionWithoutLinear, +} + + +class CoatLlamaDecoderLayer(nn.Module): + def __init__(self, config: CoatLlamaConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = COAT_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.BeforeAttention = CoatLlamaBeforeAttentionResidual(config, self.qargs, layer_idx) + self.AfterAttention = CoatLlamaAfterAttentionResidual(config, self.qargs, layer_idx) + self.MLPResidual = CoatLlamaMLPResidual(config, self.qargs, layer_idx, self.hidden_size) + + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + quant_hidden_states: torch.Tensor, + scale_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): BF16 input to the layer of shape `(batch, seq_len, embed_dim)` + quant_hidden_states (`torch.float8_e4m3fn`): FP8 input to the layer of shape `(batch, seq_len, embed_dim)` + scale_hidden_states (`torch.bfloat16`): BF16 scaling factor to the layer of shape `(batch, seq_len, embed_dim // group_size)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer + residual, query_states, key_states, value_states = self.BeforeAttention( + hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight + ) + + # Self Attention without any linear layer + hidden_states, self_attn_weights, present_key_value = self.self_attn( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # Coat: The Output Projection Linear Layer and Residual + hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states) + + # Residual Connection, LayerNorm, and the whole MLP module + hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual( + hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight + ) + + outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class CoatLlamaPreTrainedModel(PreTrainedModel): + config_class = CoatLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class CoatLlamaModel(CoatLlamaPreTrainedModel): + """ + Coat Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`CoatLlamaDecoderLayer`] + + Args: + config: CoatLlamaConfig + """ + + def __init__(self, config: CoatLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [CoatLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Quantize + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.quantize_input_before_block = Coat_quantize_bgn(self.qargs) + self.quantize_output_after_block = Coat_quantize_end(self.qargs) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Prepare the input for Coat decoderlayer + hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states) + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + quant_hidden_states, + scale_hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + quant_hidden_states, + scale_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Summarize the output of the Decoder Layer + hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + _update_causal_mask = LlamaModel._update_causal_mask + + +class CoatLlamaForCausalLM(CoatLlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = CoatLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + 
self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +# TODO +# class LlamaForSequenceClassification(LlamaPreTrainedModel): + +# class LlamaForQuestionAnswering(LlamaPreTrainedModel): + +# class LlamaForTokenClassification(LlamaPreTrainedModel): + + +def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]): + compatible_state_dict = {} + + for key, value in state_dict.items(): + if fnmatch(key, "*self_attn.q_proj*"): + new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj") + elif fnmatch(key, "*self_attn.k_proj*"): + new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj") + elif fnmatch(key, "*self_attn.v_proj*"): + new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj") + elif fnmatch(key, "*self_attn.o_proj*"): + new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj") + + elif fnmatch(key, "*mlp.gate_proj*"): + new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj") + elif fnmatch(key, "*mlp.up_proj*"): + new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj") + elif fnmatch(key, "*mlp.down_proj*"): + new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj") + + else: + new_key = key + + compatible_state_dict[new_key] = value + + return compatible_state_dict + + +AutoConfig.register("fp8_llama", CoatLlamaConfig) +AutoModel.register(CoatLlamaConfig, CoatLlamaModel) +AutoModelForCausalLM.register(CoatLlamaConfig, CoatLlamaForCausalLM) diff --git a/llava/model/coat/activation/models/coat_llama_convert_from_hf.py b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..c02da9e1b4211f8becaa2a54dc30a46fcf88185b --- /dev/null +++ b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import argparse +import os +from dataclasses import asdict, dataclass, field +from typing import Optional + +import torch +import transformers +from coat.activation.models._fp8_quantization_config import QuantizationConfig +from coat.activation.models.coat_llama import CoatLlamaConfig, CoatLlamaForCausalLM, make_state_dict_compatible +from transformers import AutoConfig, AutoModelForCausalLM + + +@dataclass +class ConvertArguments: + model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"}) + save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"}) + cache_dir: str = field(default=None, metadata={"help": "Directory to cache the model"}) + + +def download_and_convert_llama(convert_args: ConvertArguments, quantization_args: QuantizationConfig): + """ + Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`, + and saves the converted model. + + Args: + model_name (str): The model name or path to download the LLaMA model. + save_path (str): The path where the converted model weights will be saved. + cache_dir (Optional[str]): Directory to cache the model. Defaults to None. 
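+
+        Example (illustrative only; the CLI flags mirror the dataclass field names parsed by
+        `HfArgumentParser`, and the FP8 `QuantizationConfig` options are left at their defaults here):
+            python coat_llama_convert_from_hf.py \
+                --model_name meta-llama/Llama-2-7b-hf \
+                --save_path ./llama2-coat-fp8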
+ + Returns: + None + """ + model_name = convert_args.model_name + save_path = convert_args.save_path + cache_dir = convert_args.cache_dir + + # Step 1: Download the original LLaMA model + model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 2: Initialize the model configuration for FP8 or other custom config + config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 3: Apply make_state_dict_compatible to convert weights + compatible_state_dict = make_state_dict_compatible(model.state_dict()) + + # Step 4: Create a new model instance with compatible configuration + fp8_config = CoatLlamaConfig(**config.to_dict()) + fp8_config.coat_fp8_args = asdict(quantization_args) + + converted_model = AutoModelForCausalLM.from_config(fp8_config) + converted_model.load_state_dict(compatible_state_dict) + + # Step 5: Save the converted model and configuration using save_pretrained + os.makedirs(save_path, exist_ok=True) + converted_model.save_pretrained(save_path) + print(f"Converted model saved at {save_path}") + + +if __name__ == "__main__": + # Parse command-line arguments + parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8 + convert_args, quantization_args = parser.parse_args_into_dataclasses() + + # Call the function with parsed arguments + download_and_convert_llama(convert_args, quantization_args) diff --git a/llava/model/coat/activation/models/coat_olmo.py b/llava/model/coat/activation/models/coat_olmo.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d2cdedd73c56a68396aebdab1a2762876abc53 --- /dev/null +++ b/llava/model/coat/activation/models/coat_olmo.py @@ -0,0 +1,1942 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +""" +Adapted from +[MosaiclML](https://github.com/mosaicml/examples.git) and +[minGPT](https://github.com/karpathy/minGPT.git) +""" + +from __future__ import annotations + +import logging +import math +import sys +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set, Tuple, cast + +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from olmo.aliases import PathOrStr +from olmo.beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler +from olmo.config import ( + ActivationCheckpointingStrategy, + ActivationType, + BlockType, + CheckpointType, + FSDPWrapStrategy, + InitFnType, + LayerNormType, + ModelConfig, + QuantActivationConfig, + ShardedCheckpointerType, + TrainConfig, +) +from olmo.exceptions import OLMoConfigurationError +from olmo.initialization import init_normal +from olmo.model import ( + Activation, + BufferCache, + Dropout, + LayerNorm, + LayerNormBase, + OLMo, + OLMoBlock, + OLMoBlockGroup, + OLMoGenerateOutput, + OLMoOutput, + RMSLayerNorm, + RotaryEmbedding, + _non_meta_init_device, + activation_checkpoint_function, + alibi_attention_bias, + causal_attention_bias, + get_causal_attention_bias, + should_checkpoint_block, +) +from olmo.torch_util import ensure_finite_, get_cumulative_document_lengths +from torch import einsum + +from ..real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) +from ._fp8_weightcache import FP8CacheWeightModule +from ._fp8manager import FP8Manager + +if sys.version_info.minor > 8: + from collections.abc import MutableMapping +elif sys.version_info.minor == 8: + from typing import MutableMapping +else: + raise SystemExit("This script supports Python 3.8 or higher") + +__all__ = [ + "LayerNormBase", + "LayerNorm", + "RMSLayerNorm", + "RotaryEmbedding", + "Activation", + "GELU", + "ReLU", + "SwiGLU", + "OLMoBlock", + "OLMoSequentialBlock", + "OLMo", + "OLMoOutput", + "OLMoGenerateOutput", +] + + +log = logging.getLogger(__name__) + + +class CoatOLMoBeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, fused_dims: tuple): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + self.ln_normalized_shape = config.d_model + self.att_proj = nn.Linear(config.d_model, sum(fused_dims), bias=config.include_bias, device=config.init_device) + + self.attn_norm = LayerNorm.build(config) + + def forward(self, re_x, x, s): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + weight1_s = 
self.prepare_weight(self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch) + return _CoatOLMoBeforeAttentionResidual.apply( + re_x, + x, + s, + self.att_proj.weight, + None, + None, + weight1_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch + ) + return _CoatOLMoBeforeAttentionResidual.apply( + re_x, + x, + s, + self.att_proj.weight, + weight1, + weight1_t, + weight1_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _CoatOLMoBeforeAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward( + in_x, in_s, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None + ctx.weight = weight1_origin, weight1_s + else: + ctx.weight = weight1_t, weight1_s + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return re_x, fc1_x + + @staticmethod + def backward(ctx, fp_grad, flash_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s = ctx.weight + group_size = ctx.group_size + mean, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout. + flash_g, flash_gs, flash_g_t = fp8_quantize_pertensor_transpose( + flash_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + fc1_g, att_proj_wg = fp8_linear_backward( + ln_x_t, ln_s, flash_g, flash_gs, flash_g_t, weight1_t, weight1_s, group_size + ) + + # LayerNorm + in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps) + + # Add the gradient together, and prepare the input of the next layer. + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. 
Otherwise it will not pass autograd. + return re_g, in_g, in_sg_g16, att_proj_wg, None, None, None, None, None, None, None, None, None + + +class CoatOLMoAfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + self.attn_out = nn.Linear(config.d_model, config.d_model, bias=config.include_bias, device=config.init_device) + + def forward(self, re_x, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + weight2_s = self.prepare_weight(self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch) + + return _CoatOLMoAfterAttentionResidual.apply( + re_x, + in_x, + self.attn_out.weight, + None, + None, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight2, weight2_t, weight2_s = self.prepare_weight( + self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch + ) + + return _CoatOLMoAfterAttentionResidual.apply( + re_x, + in_x, + self.attn_out.weight, + weight2, + weight2_t, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _CoatOLMoAfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, re_x, flash_x, weight2_origin, weight2, weight2_t, weight2_s, group_size, fwobits, layer_id, config, qargs + ): + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight2 is None # memory efficient + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + fc2_x = fp8_linear_forward(flash_qx, flash_s, weight2, weight2_s, False, group_size) # + + # import IPython + # IPython.embed() + # Add the activations together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, flash_qx.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight2_t is None + ctx.weight = weight2_origin, weight2_s + else: + ctx.weight = weight2_t, weight2_s + ctx.group_size = group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight2_t, weight2_s = ctx.weight + group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = 
fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + fc2_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size + ) + + return fp_grad, fc2_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class CoatOLMoMLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + self.ln_normalized_shape = config.d_model + self.act_output_multiplier = 0.5 if config.activation_type == ActivationType.swiglu else 1 + self.ff_proj = nn.Linear(config.d_model, hidden_size, bias=config.include_bias, device=config.init_device) + self.ff_out = nn.Linear( + int(self.act_output_multiplier * hidden_size), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.training = True + + # below is only used when training = False + self.ff_norm = LayerNorm.build(config) + self.act = Activation.build(config) + assert (self.act.output_multiplier * hidden_size) % 1 == 0 + + def forward(self, re_x, x, s): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + weight1_s = self.prepare_weight(self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch) + + return _CoatOLMoMLPResidual.apply( + re_x, + x, + s, + self.ff_proj.weight, + None, + None, + weight1_s, + self.ff_out.weight, + None, + None, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch + ) + + return _CoatOLMoMLPResidual.apply( + re_x, + x, + s, + self.ff_proj.weight, + weight1, + weight1_t, + weight1_s, + self.ff_out.weight, + weight2, + weight2_t, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _CoatOLMoMLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = 
in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward( + in_x, in_s, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + fc1_x, fc1_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) + + # NOTE: Becareful of the order + up_x, gate_x = fc1_x.chunk(2, dim=-1) + up_s, gate_s = fc1_s.chunk(2, dim=-1) + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight2 is None: # memory efficient + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + fc2_x = fp8_linear_forward(mul_x, mul_s, weight2, weight2_s, False, group_size) + + # Add the activation together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s) = ctx.weight + group_size = ctx.group_size + mean, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + fc2_g, weight2_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight2_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2) = fp8_mul_backward(silu_x, silu_s, up_x, up_s, fc2_g, group_size, fwobits["babit"]) + + # Silu activation + silu_g, silu_gs = fp8_silu_backward(gate_x, gate_s, mul_g1, group_size, fwobits["babit"]) + + # Prepare the input of Linear Layer. NOTE: Becareful of the order + gateup_g = torch.cat([mul_g2, silu_g], dim=-1) + gateup_gs = torch.cat([mul_gs2, silu_gs]) + gateup_gs = torch.max(gateup_gs) + + gateup_g, gateup_gs, gateup_g_t = fp8_division_transpose( + gateup_g, group_size, fwobits["babit"], gateup_gs, stochastic=False + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, gateup_g, gateup_gs, gateup_g_t, weight1_t, weight1_s, group_size + ) + + # layerNorm + in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class CoatOLMoBlock(nn.Module): + """ + A base class for transformer block implementations. + """ + + def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache): + super().__init__() + self.layer_id = layer_id + self.config = config + self.qargs = qargs + self.hidden_size = ( + config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model + ) + self.__cache = cache + assert config.d_model % config.n_heads == 0 + + self._activation_checkpoint_fn: Callable | None = None + + # Dropout. + self.dropout = Dropout(config.residual_dropout) + + # Layer norms. 
+ self.k_norm: LayerNormBase | None = None + self.q_norm: LayerNormBase | None = None + if config.attention_layer_norm: + assert config.effective_n_kv_heads is not None + self.k_norm = LayerNormBase.build( + config, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, + elementwise_affine=config.attention_layer_norm_with_affine, + ) + self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) + + # Make sure QKV clip coefficient is positive, otherwise it's not well-defined. + if config.clip_qkv is not None: + assert config.clip_qkv > 0 + + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + if not self.qargs.use_quantize_model: + # Attention output projection. + self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward output projection. + self.ff_out = nn.Linear( + int(self.act.output_multiplier * self.hidden_size), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.ff_out._is_residual = True # type: ignore + + # Rotary embeddings. + if self.config.rope: + self.rotary_emb = RotaryEmbedding(config, self.__cache) + + self.flash_attn_func = None + self.flash_attn_varlen_func = None + if config.flash_attention: + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore + + self.flash_attn_func = flash_attn_func + self.flash_attn_varlen_func = flash_attn_varlen_func + except ModuleNotFoundError: + pass + + def reset_parameters(self): + if self.k_norm is not None: + self.k_norm.reset_parameters() + if self.q_norm is not None: + self.q_norm.reset_parameters() + + if not self.qargs.use_quantize_model: + if self.config.init_fn == InitFnType.normal: + attn_out_std = ff_out_std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + + elif self.config.init_fn == InitFnType.mitchell: + attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1))) + ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1))) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + elif self.config.init_fn == InitFnType.full_megatron: + attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor) + init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor) + + def set_activation_checkpointing( + self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None + ): + if strategy == ActivationCheckpointingStrategy.fine_grained: + self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config) + else: + self._activation_checkpoint_fn = None + + @classmethod + def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor: + target_dtype = input_dtype + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. 
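+        # Illustrative example (not part of the original code): inside
+        #     with torch.autocast("cuda", dtype=torch.bfloat16):
+        # `torch.is_autocast_enabled()` is True and `torch.get_autocast_gpu_dtype()` returns
+        # torch.bfloat16, so a float32 attention bias would be down-cast to bfloat16 by the branch below.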
+ if bias.device.type == "cuda" and torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + target_dtype = torch.get_autocast_cpu_dtype() + if bias.dtype != target_dtype: + bias = bias.to(target_dtype) + ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) + return bias + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Computes scaled dot product attention on query, key and value tensors, using an optional + attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. + """ + if max_doc_len is not None and cu_doc_lens is not None: + assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking" + assert attn_mask is None, "attn-mask is currently not supported with document masking" + B, T, D = q.size(0), q.size(2), q.size(3) + r = self.flash_attn_varlen_func( + q.transpose(1, 2).view(B * T, -1, D), + k.transpose(1, 2).view(B * T, -1, D), + v.transpose(1, 2).view(B * T, -1, D), + cu_doc_lens, + cu_doc_lens, + max_doc_len, + max_doc_len, + dropout_p=dropout_p, + causal=is_causal, + ) + return r.view(B, T, -1, D).transpose(1, 2) + elif self.flash_attn_func is not None and attn_mask is None: + r = self.flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal + ) + return r.transpose(1, 2) + else: + # torch's sdpa doesn't support GQA, so we're doing this + assert k.size(1) == v.size(1) + num_kv_heads = k.size(1) + num_q_heads = q.size(1) + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + + return F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + ) + + def attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_bias: torch.Tensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + B, T, C = q.size() # batch size, sequence length, d_model + dtype = k.dtype + + # Optionally apply layer norm to keys and queries. + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q).to(dtype=dtype) + k = self.k_norm(k).to(dtype=dtype) + + # Move head forward to be next to the batch dim. 
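+        # Worked shape example (illustrative numbers, not taken from any config): with C = d_model = 4096,
+        # n_heads = 32 (head size 128) and effective_n_kv_heads = 8, the reshapes below give
+        #     q: (B, 32, T, 128)    k, v: (B, 8, T, 128)
+        # i.e. grouped-query attention with 4 query heads sharing each key/value head; the SDPA fallback
+        # above repeat_interleaves k and v back to 32 heads because torch's SDPA (as used here) has no GQA path.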
+ # shape: (B, nh, T, hs) + q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + present = (k, v) if use_cache else None + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + + if self.config.rope: + # Apply rotary embeddings. + q, k = self.rotary_emb(q, k) + + if attention_bias is not None: + # Resize and cast attention bias. + # The current dtype of the attention bias might not match the dtype that the SDP attn function will + # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding + # as down-casting the attention bias to the autocast precision will result in -infs, which will + # cause the SDP attn function to produce NaNs. + attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype) + + # Get the attention scores. + # shape: (B, nh, T, hs) + att = self._scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_bias, + dropout_p=0.0 if not self.training else self.config.attention_dropout, + is_causal=attention_bias is None, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + # Re-assemble all head outputs side-by-side. + att = att.transpose(1, 2).contiguous().view(B, T, C) + + # Apply output projection. NOTE: We move the attn output outside of this attention function + return att, present + + @abstractmethod + def forward( + self, + x: torch.Tensor, + attention_bias: torch.FloatTensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + raise NotImplementedError + + @classmethod + def build(cls, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache) -> OLMoBlock: + if config.block_type == BlockType.sequential: + return CoatOLMoSequentialBlock(layer_id, config, qargs, cache) + elif config.block_type == BlockType.llama: + return CoatOLMoLlamaBlock(layer_id, config, qargs, cache) + else: + raise NotImplementedError(f"Unknown block type: '{config.block_type}'") + + +class CoatOLMoSequentialBlock(CoatOLMoBlock): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). To compute it as ``LN(MLP(x + LN(Attention(x))))``, + use the flag `norm_after`. + """ + + def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache): + super().__init__(layer_id, config, qargs, cache) + # Attention input projection. 
Projects x -> (q, k, v) + + assert not self.config.norm_after, "COAT currently does not support PostNorm" + + head_dim = config.d_model // config.n_heads + self.fused_dims = ( + config.d_model, + config.effective_n_kv_heads * head_dim, + config.effective_n_kv_heads * head_dim, + ) + + if self.qargs.use_quantize_model: + self.BeforeAttention = CoatOLMoBeforeAttentionResidual(config, qargs, self.layer_id, self.fused_dims) + self.AfterAttention = CoatOLMoAfterAttentionResidual(config, qargs, self.layer_id) + self.MLPResidual = CoatOLMoMLPResidual(config, qargs, self.layer_id, self.hidden_size) + else: + self.att_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device + ) + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + # Layer norms. + self.attn_norm = LayerNorm.build(config, size=config.d_model) + self.ff_norm = LayerNorm.build(config, size=config.d_model) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + + if self.qargs.use_quantize_model: # The initialization appears here, not in CoatOLMoBlock's reset_parameters + if self.config.init_fn == InitFnType.normal: + attn_out_std = ff_out_std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + + elif self.config.init_fn == InitFnType.mitchell: + attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1))) + ff_out_std = 1 / (math.sqrt(2 * self.MLPResidual.ff_out.in_features * (self.layer_id + 1))) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + elif self.config.init_fn == InitFnType.full_megatron: + attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.AfterAttention.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor) + init_normal(self.MLPResidual.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor) + + if self.config.init_fn == InitFnType.normal: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + std = 1 / math.sqrt(self.config.d_model) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + if not self.qargs.use_quantize_model: + init_normal(self.att_proj, std, cutoff_factor) + init_normal(self.ff_proj, std, cutoff_factor) + else: + init_normal(self.BeforeAttention.att_proj, std, cutoff_factor) + init_normal(self.MLPResidual.ff_proj, std, cutoff_factor) + + def forward( + self, + x: torch.Tensor, + qx: torch.Tensor, + sx: torch.Tensor, + attention_bias: torch.Tensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + # Get query, key, value projections. 
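+        # Illustrative split of the fused projection (example numbers, not from the config):
+        # with d_model = 4096, n_heads = 32 (head_dim = 128) and effective_n_kv_heads = 8,
+        # fused_dims = (4096, 1024, 1024), so att_proj emits 6144 features and
+        #     q, k, v = qkv.split(self.fused_dims, dim=-1)
+        # yields q: (B, T, 4096), k: (B, T, 1024), v: (B, T, 1024).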
+ # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + + # import IPython + # IPython.embed() + + if self.qargs.use_quantize_model: + # if False: + x, qkv = self.BeforeAttention(x, qx, sx) + else: + # apply norm before + h = self.attn_norm(x) + + qkv = self.BeforeAttention.att_proj(h) + + if self.config.clip_qkv is not None: + qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + q, k, v = qkv.split(self.fused_dims, dim=-1) + + # Get attention scores. + att, cache = self.attention( + q, + k, + v, + attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + # import IPython + # IPython.embed() + if self.qargs.use_quantize_model: + # if False: + x, qx, sx = self.AfterAttention(x, att) + else: + att = self.AfterAttention.attn_out(att) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + if self.qargs.use_quantize_model: + # if False: + x, qx, sx = self.MLPResidual(x, qx, sx) + else: + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + + x = self.ff_norm(x) + + x = self.MLPResidual.ff_proj(x) + + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.MLPResidual.ff_out(x) + + x = self.dropout(x) + x = og_x + x + + # import IPython + # IPython.embed() + + return x, qx, sx, cache + + +class CoatOLMoLlamaBlock(OLMoBlock): + """ + This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). This block is similar to `OLMoSequentialBlock` + but some operations have slightly different implementations to imitate the + behavior of Llama. + """ + + def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache): + super().__init__(layer_id, config, qargs, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + self.__cache = cache + + # Attention input projection. Projects x -> (q, k, v) + if config.multi_query_attention: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model // config.n_heads + v_proj_out_dim = config.d_model // config.n_heads + else: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model + v_proj_out_dim = config.d_model + self.q_proj = nn.Linear(config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device) + self.k_proj = nn.Linear(config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device) + self.v_proj = nn.Linear(config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device) + + # Feed-forward input projection. + self.ff_proj = nn.Linear(config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. 
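+        # Example (illustrative numbers): with InitFnType.mitchell and d_model = 2048, the branch below
+        # uses std = 1 / sqrt(2048) ≈ 0.0221 for q_proj/k_proj/v_proj and ff_proj, independent of layer_id.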
+ + if self.config.init_fn == InitFnType.normal: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + std = 1 / math.sqrt(self.config.d_model) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.q_proj, std, cutoff_factor) + init_normal(self.k_proj, std, cutoff_factor) + init_normal(self.v_proj, std, cutoff_factor) + init_normal(self.ff_proj, std, cutoff_factor) + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> torch.Tensor: + if max_doc_len is not None or cu_doc_lens is not None: + raise NotImplementedError(f"attention document masking is not implemented for {self.__class__.__name__}") + + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) + + if is_causal: + assert attn_mask is None + + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len] + elif attn_mask is not None: + attn_bias = attn_mask.to(q.dtype) + else: + attn_bias = torch.zeros_like(attn_weights) + + attn_weights += attn_bias + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout_p) + return torch.matmul(attn_weights, v) + + def forward( + self, + x: torch.Tensor, + qx: torch.Tensor, + sx: torch.Tensor, + attention_bias: torch.Tensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + x_normed = self.attn_norm(x) + q = self.q_proj(x_normed) + k = self.k_proj(x_normed) + v = self.v_proj(x_normed) + + if self.config.clip_qkv is not None: + q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Get attention scores. + att, cache = self.attention( + q, + k, + v, + attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + att = self.attn_out(att) # NOTE: we move the attn_out outside the self.attention module + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. 
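+        # Equivalently (illustrative summary of the lines below, ignoring activation checkpointing):
+        #     x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.ff_norm(x)))))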
+ # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x = self.ff_proj(x) + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class CoatOLMoBlockGroup(nn.ModuleList): + def __init__(self, config: ModelConfig, layer_offset: int, modules: Iterable[nn.Module] | None = None): + super().__init__(modules) + self.config = config + self.layer_offset = layer_offset + self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + + def forward( + self, + x: torch.Tensor, + attention_bias: torch.FloatTensor | None = None, + layers_past: list[tuple[torch.Tensor, torch.Tensor]] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]] | None]: + attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None + for block_idx, block in enumerate(self): + layer_past = None if layers_past is None else layers_past[block_idx] + block_idx += self.layer_offset + if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( # type: ignore + block, + x, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block( + x, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + return x, attn_key_values + + def reset_parameters(self): + for block in self: + block.reset_parameters() + + def set_activation_checkpointing( + self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None + ): + self.activation_checkpointing_strategy = strategy + for block in self: + block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func) + + +class CoatOLMo(nn.Module): + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, init_params: bool = True): + super().__init__() + self.config = config + self.qargs = qargs + self.__cache = BufferCache() + + # Validate config. + if self.config.alibi and self.config.flash_attention: + raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention") + + if self.config.alibi and self.config.rope: + raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive") + + if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size: + if self.config.embedding_size < self.config.vocab_size: + raise OLMoConfigurationError("embedding size should be at least as big as vocab size") + elif self.config.embedding_size % 128 != 0: + import warnings + + warnings.warn( + "Embedding size is not a multiple of 128! 
This could hurt throughput performance.", UserWarning + ) + + self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None + self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config) + + if not ( + 0 < self.config.block_group_size <= self.config.n_layers + and self.config.n_layers % self.config.block_group_size == 0 + ): + raise OLMoConfigurationError("n layers must be divisible by block group size") + + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.embedding_size or config.vocab_size, config.d_model, device=config.init_device), + emb_drop=Dropout(config.embedding_dropout), + ln_f=LayerNorm.build(config), + ) + ) + + blocks = [CoatOLMoBlock.build(i, config, qargs, self.__cache) for i in range(config.n_layers)] + if self.config.block_group_size > 1: + block_groups = [ + CoatOLMoBlockGroup(config, i, blocks[i : i + config.block_group_size]) + for i in range(0, config.n_layers, config.block_group_size) + ] + self.transformer.update({"block_groups": nn.ModuleList(block_groups)}) + else: + self.transformer.update({"blocks": nn.ModuleList(blocks)}) + + if not (self.config.alibi or self.config.rope): + self.transformer.update( + {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)} + ) + if not config.weight_tying: + self.transformer.update( + { + "ff_out": nn.Linear( + config.d_model, + config.embedding_size or config.vocab_size, + bias=config.include_bias, + device=config.init_device, + ) + } + ) + if config.embedding_layer_norm: + self.transformer.update({"emb_norm": LayerNorm.build(config)}) + + # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights. + if init_params and self.config.init_device != "meta": + self.reset_parameters() + self.__num_fwd_flops: int | None = None + self.__num_bck_flops: int | None = None + + # Warm up cache. + if self.config.alibi: + get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config)) + self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) + + # Quantize + self.quantize_input_before_block = Coat_quantize_bgn(qargs) + self.quantize_output_after_block = Coat_quantize_end(qargs) + + set_activation_checkpointing = OLMo.set_activation_checkpointing + device = OLMo.device + reset_parameters = OLMo.reset_parameters + get_alibi_attention_bias = OLMo.get_alibi_attention_bias + + def forward( + self, + input_ids: torch.LongTensor, + input_embeddings: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + attention_bias: torch.Tensor | None = None, + past_key_values: Sequence[tuple[torch.Tensor, torch.Tensor]] | None = None, + use_cache: bool = False, + last_logits_only: bool = False, + output_hidden_states: bool | None = None, + doc_lens: torch.Tensor | None = None, + max_doc_lens: Sequence[int] | None = None, + ) -> OLMoOutput: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input + embeddings. When provided, it is treated as the output of the input embedding layer. + :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates + which input IDs are masked. A `1` value in the mask means that + the corresponding input ID should *not* be ignored. 
A `0` means + that the corresponding input ID is masked. + + This has the same meaning as the `attention_mask` in HuggingFace's `transformers` + library. + :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`, + `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used + to introduce causal or other biases. + + If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]` + indicates that the i-th element in the sequence is allowed to attend to the j-th + element in the sequence. + + If the tensor is a float tensor, it will just be added to the attention + scores before the softmax. + + The default is causal, which corresponds to a lower-diagonal byte matrix of ones. + :param past_key_values: Pre-computed keys and values for each attention block. + Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + :param use_cache: If `True`, return key and value tensors for each block. + :param last_logits_only: If `True`, only compute the logits for the last token of each sequence. + This can speed up decoding when you only care about the next token. + :param doc_lens: Document lengths to use in attention for intra-document masking. + Shape `(batch_size, max_docs)`. + :param max_doc_lens: Maximum document length for each instance in the batch. + """ + output_hidden_states = output_hidden_states if output_hidden_states is not None else False + + if past_key_values: + assert len(past_key_values) == self.config.n_layers + + batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2] + if past_key_values is None: + past_length = 0 + else: + past_length = past_key_values[0][0].size(-2) + + max_doc_len: int | None = None + cu_doc_lens: torch.Tensor | None = None + if doc_lens is not None and max_doc_lens is not None: + max_doc_len = max(max_doc_lens) + cu_doc_lens = get_cumulative_document_lengths(doc_lens) + + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore + + # Apply embedding layer norm. + if self.config.embedding_layer_norm: + x = self.transformer.emb_norm(x) + + if not (self.config.alibi or self.config.rope): + # Get positional embeddings. + # shape: (1, seq_len) + pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0) + # shape: (1, seq_len, d_model) + pos_emb = self.transformer.wpe(pos) # type: ignore + x = pos_emb + x + + # Apply dropout. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.emb_drop(x) # type: ignore + + # Transform the attention mask into what the blocks expect. + if attention_mask is not None: + # shape: (batch_size, 1, 1, seq_len) + attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :] + attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min + + # Merge attention mask with attention bias. + if ( + attention_bias is not None + or attention_mask is not None + or self.config.alibi + # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly + # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute + # scores correctly. 
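+ # In short: a dense float attention bias is materialized whenever any masking signal is
+ # present (a user-supplied bias, a padding mask, ALiBi) or a KV cache is in use; otherwise
+ # attention_bias stays None and the per-block attention path handles the plain causal case.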
+ or past_key_values is not None + ): + if attention_bias is None and self.config.alibi: + attention_bias = get_causal_attention_bias( + self.__cache, past_length + seq_len, x.device + ) + self.get_alibi_attention_bias(past_length + seq_len, x.device) + elif attention_bias is None: + attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device) + elif attention_bias.dtype in (torch.int8, torch.bool): + attention_bias = attention_bias.to(dtype=torch.float) + attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min) + + # Transform to the right shape and data type. + mask_len = seq_len + if attention_mask is not None: + mask_len = attention_mask.shape[-1] + elif past_key_values is not None: + mask_len = past_key_values[0][0].shape[-2] + seq_len + attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float) + + # Add in the masking bias. + if attention_mask is not None: + attention_bias = attention_bias + attention_mask + # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf. + # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead + # it can produce NaNs. + ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False) + + attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None + + # decoder layers + all_hidden_states = [] + + # Prepare the input for COAT decoderlayer + x, qx, sx = self.quantize_input_before_block(x) + + # Apply blocks one-by-one. + if self.config.block_group_size == 1: + for block_idx, block in enumerate(self.transformer.blocks): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layer_past = None if past_key_values is None else past_key_values[block_idx] + if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx): + # shape: (batch_size, seq_len, d_model) + x, qx, sx, cache = self._activation_checkpoint_fn( + block, + x, + qx, + sx, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + else: + # shape: (batch_size, seq_len, d_model) + x, qx, sx, cache = block( + x, + qx, + sx, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + else: + for group_idx, block_group in enumerate(self.transformer.block_groups): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layers_past = ( + None + if past_key_values is None + else past_key_values[ + group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size + ] + ) + x, cache = block_group( + x, + attention_bias=attention_bias, + layers_past=layers_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + if attn_key_values is not None: + assert cache is not None + attn_key_values.extend(cache) + + # Summarize the output of the Decoder Layer + x = self.quantize_output_after_block(x, qx, sx) + + if last_logits_only: + # shape: (batch_size, 1, d_model) + x = x[:, -1, :].unsqueeze(1) + + # Apply final layer norm. 
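+ # Note that only the high-precision tensor `x` flows past this point: `quantize_output_after_block`
+ # above consumed the FP8 payload `qx` and its per-group scales `sx` and returned a single tensor,
+ # so the final norm and the logits projection below operate on `x` alone.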
+ # shape: (batch_size, seq_len or 1, d_model) + x = self.transformer.ln_f(x) # type: ignore + if output_hidden_states: + # add final hidden state post-final-layernorm, following HuggingFace's convention + all_hidden_states.append(x) + + # Get logits. + # shape: (batch_size, seq_len or 1, vocab_size) + if self.config.weight_tying: + logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore + else: + logits = self.transformer.ff_out(x) # type: ignore + if self.config.scale_logits: + logits.mul_(1 / math.sqrt(self.config.d_model)) + + return OLMoOutput( + logits=logits, + attn_key_values=attn_key_values, + hidden_states=tuple(all_hidden_states) if output_hidden_states else None, + ) + + def get_fsdp_wrap_policy(self, wrap_strategy: FSDPWrapStrategy | None = None): + if wrap_strategy is None: + return None + + # The 'recurse' mode for the wrap function does not behave like you'd expect. + # Even if we return False, it may still recurse because PyTorch does what it wants, + # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer) + # but not other linear layers within a block. + # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just + # return True in 'recurse' mode for simplicity. + size_based_module_to_wrap = {self.transformer.wte} + if hasattr(self.transformer, "ff_out"): + size_based_module_to_wrap.add(self.transformer.ff_out) + + if wrap_strategy == FSDPWrapStrategy.by_block: + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, CoatOLMoBlock) + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_and_size: + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, (CoatOLMoBlock,)) or module in size_based_module_to_wrap + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_group: + if self.config.block_group_size <= 1: + raise OLMoConfigurationError( + "'by_block_group' FSDP wrapping strategy requires block group size greater than 1" + ) + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, CoatOLMoBlockGroup) + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size: + if self.config.block_group_size <= 1: + raise OLMoConfigurationError( + "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1" + ) + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, (CoatOLMoBlockGroup,)) or module in size_based_module_to_wrap + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.size_based: + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + + return size_based_auto_wrap_policy + elif wrap_strategy in { + FSDPWrapStrategy.one_in_two, + FSDPWrapStrategy.one_in_three, + FSDPWrapStrategy.one_in_four, + FSDPWrapStrategy.one_in_five, + }: + c = { + FSDPWrapStrategy.one_in_two: 2, + FSDPWrapStrategy.one_in_three: 3, + FSDPWrapStrategy.one_in_four: 4, + FSDPWrapStrategy.one_in_five: 5, + }[wrap_strategy] + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del 
nonwrapped_numel + wrap = isinstance(module, CoatOLMoBlock) and module.layer_id % c == 0 + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + else: + raise NotImplementedError(wrap_strategy) + + num_params = OLMo.num_params + + @property + def num_fwd_flops(self): + if self.__num_fwd_flops: + return self.__num_fwd_flops + + # embedding table is just a lookup in the forward pass + n_params = self.num_params(include_embedding=False) + # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network + # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param + # this gets us FLOPs / token + params_flops_per_token = 2 * n_params + # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2) + attn_flops_per_token = self.config.n_layers * 2 * 2 * (self.config.d_model * self.config.max_sequence_length) + self.__num_fwd_flops = params_flops_per_token + attn_flops_per_token + return self.__num_fwd_flops + + @property + def num_bck_flops(self): + if self.__num_bck_flops: + return self.__num_bck_flops + + n_params = self.num_params() + params_flops_per_token = 4 * n_params + attn_flops_per_token = self.config.n_layers * 8 * (self.config.d_model * self.config.max_sequence_length) + self.__num_bck_flops = params_flops_per_token + attn_flops_per_token + return self.__num_bck_flops + + generate = OLMo.generate + + @classmethod + def from_checkpoint( + cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: CheckpointType | None = None + ) -> CoatOLMo: + """ + Load an OLMo model from a checkpoint. + """ + from olmo.util import resource_path + + # Guess checkpoint type. + if checkpoint_type is None: + try: + if resource_path(checkpoint_dir, "model.pt").is_file(): + checkpoint_type = CheckpointType.unsharded + else: + checkpoint_type = CheckpointType.sharded + except FileNotFoundError: + checkpoint_type = CheckpointType.sharded + + # Load config. + config_path = resource_path(checkpoint_dir, "config.yaml") + model_config = ModelConfig.load(config_path, key="model", validate_paths=False) + + if checkpoint_type == CheckpointType.unsharded: + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + model_config.init_device = "cpu" + model = CoatOLMo(model_config) + + # Load state dict directly to target device. + state_dict_path = resource_path(checkpoint_dir, "model.pt") + state_dict = torch.load(state_dict_path, map_location="cpu") + model.load_state_dict(model._make_state_dict_compatible(state_dict)[0]) + model = model.to(torch.device(device)) + else: + train_config = TrainConfig.load(config_path) + if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core: + from olmo_core.distributed.checkpoint import load_model_and_optim_state # type: ignore + + model_config.init_device = device + model = CoatOLMo(model_config) + load_model_and_optim_state(checkpoint_dir, model) + else: + # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new + from olmo.checkpoint import load_model_state + + # Initialize model on target device. In this case the state dict is loaded in-place + # so it's not necessary to start on CPU if the target device is a GPU. + model_config.init_device = device + model = CoatOLMo(model_config) + + # Load state dict in place. 
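+ # Unlike the unsharded branch above, no `.to(device)` move follows: the module is already
+ # constructed with `init_device = device`, and `load_model_state` fills its parameters
+ # in place from the sharded checkpoint files.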
+ load_model_state(checkpoint_dir, model) + + return model.eval() + + def _make_state_dict_compatible( + self, state_dict: dict[str, torch.Tensor] + ) -> tuple[dict[str, torch.Tensor], dict[str, set[str]]]: + """ + Handles some cases where the state dict is valid yet may need to be transformed in order to + be loaded. + + This modifies the state dict in-place and also returns it, along with a mapping of original key + names to new key names in cases where the keys were simply renamed. That mapping can be used + to make a corresponding optimizer state dict compatible as well. + """ + import re + from fnmatch import fnmatch + + new_keys_to_og_keys: dict[str, str] = {} + + # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is + # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work + # fine without the prefixes. This also simplifies the other steps below. + for key in list(state_dict.keys()): + state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = key + + # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222 + if self.config.block_type == BlockType.sequential: + for key in list(state_dict.keys()): + if fnmatch(key, "transformer.*.norm.weight"): + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone() + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.*.norm.bias"): + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone() + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + + # Realquantization will change the place the linear layers happen + if self.qargs.use_quantize_model == "coat_real": + for key in list(state_dict.keys()): + if fnmatch(key, "transformer.blocks.*.att_proj.weight") and "BeforeAttention" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("att_proj.weight", "BeforeAttention.att_proj.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.blocks.*.attn_out.weight") and "AfterAttention" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("attn_out.weight", "AfterAttention.attn_out.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.blocks.*.ff_proj.weight") and "MLPResidual" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("ff_proj.weight", "MLPResidual.ff_proj.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.blocks.*.ff_out.weight") and "MLPResidual" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("ff_out.weight", "MLPResidual.ff_out.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + + # For loading a state dict that was saved with a different 
`block_group_size`. + if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys(): + state_dict_block_group_size = len( + [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")] + ) + else: + state_dict_block_group_size = 1 + if self.config.block_group_size != state_dict_block_group_size: + log.info( + f"Regrouping state dict blocks from group size {state_dict_block_group_size} to " + f"group size {self.config.block_group_size}" + ) + # For simplicity we're first going to flatten out the block groups in the state dict (if necessary) + # and then (re-)group them into the right block sizes. + if state_dict_block_group_size > 1: + for key in list(state_dict.keys()): + if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None: + group_idx, group_block_idx = int(m.group(1)), int(m.group(2)) + block_idx = (group_idx * state_dict_block_group_size) + group_block_idx + state_dict[ + ( + new_key := key.replace( + f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}." + ) + ) + ] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key) + + if self.config.block_group_size > 1: + # Group the state dict blocks into the right block size. + for key in list(state_dict.keys()): + if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None: + block_idx = int(m.group(1)) + group_idx, group_block_idx = ( + block_idx // self.config.block_group_size, + block_idx % self.config.block_group_size, + ) + state_dict[ + ( + new_key := key.replace( + f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}." + ) + ) + ] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key) + + og_keys_to_new: dict[str, set[str]] = defaultdict(set) + for new_key, og_key in new_keys_to_og_keys.items(): + og_keys_to_new[og_key].add(new_key) + + return state_dict, og_keys_to_new diff --git a/llava/model/coat/activation/real_quantization/__init__.py b/llava/model/coat/activation/real_quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61875efb5d6a495fdab36438ea99a8784a88c4cf --- /dev/null +++ b/llava/model/coat/activation/real_quantization/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
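+
+# A minimal round-trip sketch of the group-quantization helpers re-exported below. It assumes
+# a CUDA device with Triton available, a last dimension divisible by the group size QB=16, and
+# that `torch.float8_e4m3fn` is an fp8type accepted by these kernels (i.e. a key of
+# `common.FP8_MAX_VALUE`); adjust the dtype or pass the matching string key if not.
+#
+#   import torch
+#   from llava.model.coat.activation.real_quantization import fp8_quantize, fp8_dequantize
+#
+#   x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
+#   qx, sx = fp8_quantize(x, 16, torch.float8_e4m3fn)  # FP8 payload + one scale per 1x16 group
+#   x_hat = fp8_dequantize(qx, sx, 16)                 # back to bf16, recovers x up to quantization error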
+ +# Activation +# Utils +from ._dequantize import fp8_dequantize +from ._division import fp8_division +from ._division_transpose import fp8_division_transpose +from ._quantize import fp8_quantize +from ._quantize_pertensor import fp8_quantize_pertensor +from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose +from ._transpose import fp8_transpose +from .add_bwd import fp8_add_Ifp_Ifp_Ofp_Opt +from .add_fwd import fp8_add_Ifp_Ifp_Ofp_Og16 + +# Normalization +from .func_layernorm_noparam import fp8_layernorm_noparam_backward, fp8_layernorm_noparam_forward +from .func_quantize import Coat_quantize_bgn, Coat_quantize_end +from .func_rmsnorm import fp8_rmsnorm_backward, fp8_rmsnorm_forward +from .gelu_bwd import fp8_gelu_backward +from .gelu_fwd import fp8_gelu_forward + +# linear and add +from .linear import fp8_linear_backward, fp8_linear_forward +from .mul_bwd import fp8_mul_backward +from .mul_fwd import fp8_mul_forward +from .silu_bwd import fp8_silu_backward +from .silu_fwd import fp8_silu_forward diff --git a/llava/model/coat/activation/real_quantization/_dequantize.py b/llava/model/coat/activation/real_quantization/_dequantize.py new file mode 100644 index 0000000000000000000000000000000000000000..4347932f9c3691ffada9d59258337aac75b74942 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_dequantize.py @@ -0,0 +1,162 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""Quantize Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_dequantize_kernel( + output_ptr, # output + input_ptr, + input_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of output stride + output_stride_0, + output_stride_1, # output stride + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + output = input * scale_input + + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + output = output.to(output_ptr.dtype.element_ty) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +def fp8_dequantize(x, s_x, QB): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + y = torch.empty_like(x, dtype=torch.bfloat16) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_dequantize_kernel[grid]( + y, + x, + s_x, + M, + N, + SN, + QB, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + y.stride(0), + y.stride(1), + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + + return y diff --git a/llava/model/coat/activation/real_quantization/_division.py b/llava/model/coat/activation/real_quantization/_division.py new file mode 100644 index 0000000000000000000000000000000000000000..335ff1a67795dc858a740c6618466ff79fb554cb --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_division.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. 
+# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Quantize and Transpose Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_division_kernel( + output_ptr, # output + input_ptr, + input_scale_ptr, # input + noise_ptr, # noise for stochastic + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit: tl.constexpr, + m_bit: tl.constexpr, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + scale_output = tl.load(input_scale_ptr) + scale_output = scale_output.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + # Quantize + output = tl.fdiv(output, scale_output) + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + # noise_block_ptr = tl.make_block_ptr( + # base=noise_ptr, + # shape=(M, N), + # strides=(input_stride_0, input_stride_1), + # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + # block_shape=(BLOCK_M, BLOCK_N), + # order=(1, 0) + # ) + # noise = tl.load(noise_block_ptr) + + offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N) + noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1 + noise = tl.rand(0, noise_offset) + + output = _stochastic_rounding(output, noise, e_bit, m_bit) + + output = output.to(output_ptr.type.element_ty) + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, 
output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +@triton.jit +def _stochastic_rounding(output, noise, e_bit: tl.constexpr, m_bit: tl.constexpr): + subnormal_min = tl.exp2(2 - tl.exp2(e_bit - 1) - m_bit) + # subnormal_should_be = tl.exp2(2 - tl.exp2(e_bit) - 1) + + output_int32 = tl.cast(output, tl.int32, bitcast=True) + output_int32 = output_int32 & 0x7F800000 + output_float32 = tl.cast(output_int32, tl.float32, bitcast=True) + output_exp = tl.maximum(output_float32, subnormal_min) + + noise_rescale = tl.exp2(m_bit) + (output_exp == subnormal_min) * ( + 1 - tl.exp2(m_bit) + ) # 2^m_bit for normal, 1 for subnormal + + noise = output_exp * noise / noise_rescale + sign = 1 - 2 * libdevice.signbit(output) + output = tl.abs(output) + noise + + minmax_ratio = 2 + (output_exp == subnormal_min) * (tl.exp2(m_bit) - 2) # 2 for normal, and 2^M for subnormal + output = sign * tl.clamp(output, min=output_exp, max=minmax_ratio * output_exp) + + return output + + +def fp8_division(x, QB, fp8type, s_y=None, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + if stochastic: + # noise = torch.zeros_like(x, dtype=torch.float32).uniform_(-0.5, 0.5) + noise = None + else: + noise = None + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + + y = torch.empty_like(x, dtype=fp8type) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + if s_y is None: + s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_division_kernel[grid]( + y, + x, + s_y, + noise, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + + return y, s_y # y_t is expected to be 2D tensor diff --git a/llava/model/coat/activation/real_quantization/_division_transpose.py b/llava/model/coat/activation/real_quantization/_division_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..7fbabaa31e4f47482c98d69431a9dfae488a2f15 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_division_transpose.py @@ -0,0 +1,215 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import _stochastic_rounding +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Division and Transpose Operator""" +"""Input uses full-precision/BF16""" +"""Output uses per tensor quantization""" +"""Output_t uses per tensor quantization and is transposed, but is flattened to 2D""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), # triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,) + # configs=[triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)], # + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_division_transpose_kernel( + output_ptr, + output_t_ptr, # output + input_ptr, + input_scale_ptr, # input + noise_ptr, # noise for stochastic + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + output_t_stride_0, + output_t_stride_1, # output stride + SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor + STOCHASTIC: tl.constexpr, + ONLY_TRANSPOSED: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + scale_output = tl.load(input_scale_ptr) + scale_output = scale_output.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + # Quantize + output = tl.fdiv(output, scale_output) + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + # noise_block_ptr = tl.make_block_ptr( + # base=noise_ptr, + # shape=(M, N), + # strides=(input_stride_0, input_stride_1), + # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + # block_shape=(BLOCK_M, BLOCK_N), + # order=(1, 0) + # ) + # noise = tl.load(noise_block_ptr) + + offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N) + noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1 + noise = tl.rand(0, noise_offset) + + output = _stochastic_rounding(output, noise, e_bit, m_bit) + + output = output.to(output_ptr.type.element_ty) + # tl.device_print("3: ", output) + output_t = tl.trans(output) + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + output_t_block_ptr = tl.make_block_ptr( + base=output_t_ptr, + shape=(N, M), + strides=(output_t_stride_0, output_t_stride_1), + offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M), + block_shape=(BLOCK_N, BLOCK_M), + 
order=(1, 0), + ) + if not ONLY_TRANSPOSED: + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1)) + + +def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False, only_transposed=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + if stochastic: + # noise = torch.empty_like(x, dtype=torch.float32).uniform_(-0.5, 0.5) + noise = None + else: + noise = None + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + + y = torch.empty_like(x, dtype=fp8type) + y_t = torch.empty((N, M), dtype=fp8type, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + if s_y is None: + # print("Warning: do not specify s_y in fp8_division_transpose") + s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_division_transpose_kernel[grid]( + y, + y_t, + x, + s_y, + noise, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + y_t.stride(0), + y_t.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ONLY_TRANSPOSED=only_transposed, + ) + + if not only_transposed: + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + + return y, s_y, y_t # y_t is expected to be 2D tensor + else: + return y_t, s_y diff --git a/llava/model/coat/activation/real_quantization/_memory_io.py b/llava/model/coat/activation/real_quantization/_memory_io.py new file mode 100644 index 0000000000000000000000000000000000000000..b04d2a7afcf749b634cbd5d88c6d99b42030d5ff --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_memory_io.py @@ -0,0 +1,180 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +CONST_BLOCK = 32 + +# The kernel with 1 load operation and 4 store operation +def get_configs_io_block(): + configs = [] + for nstages in [3, 4, 5, 6]: + for block_m in [32, 64, 128]: + for block_n in [32, 64, 128]: + for nwarps in [4, 8, 16, 32]: + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n}, + num_stages=nstages, + num_warps=nwarps, + ) + ) + return configs + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.jit +def bench_memory_io_kernel_forward( + output_ptr, + input_ptr, + M, + N, + B: tl.constexpr, + input_stride_0, + input_stride_1, + output_stride_0, + output_stride_1, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_M = tl.cdiv(M, BLOCK_M) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = input * 2 + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + output = output.to(output_ptr.type.element_ty) + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +def bench_memory_io_forward(x, B): + # defining the input and output tensor + M, N = x.shape + + y = torch.empty_like(x, dtype=x.dtype) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + bench_memory_io_kernel_forward[grid]( + y, + x, + M, + N, + B, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + ) + return y + + +configs = [] +for SL in [8192]: + configs.append( + triton.testing.Benchmark( # test different matrix size influence + x_names=["CDIM"], + x_vals=[1024, 2048, 4096, 8192], + line_arg="dtype", + line_vals=[torch.int8, torch.float16, torch.float32], + line_names=["float8", "float16", "float32"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="time-cost", + plot_name=f"INT8GELU", + args={"SL": SL, "B": CONST_BLOCK, "provider": "triton", "mode": "time-consuming"}, + ) + ) + + +@triton.testing.perf_report(configs) +def bench_load_store( + SL, CDIM, B, provider, dtype, mode="forward" +): # I only use triton as the provider, and mode when benchmarking + # create data + x = torch.randn(SL, CDIM, dtype=torch.float32).cuda() + x = x.to(dtype) + + quantiles = [0.5, 0.2, 0.8] + # utility functions + if provider == "triton": + + def y_fwd(): + bench_memory_io_forward(x, B) + + if provider == "torch": + torch_gelu = torch.nn.GELU() + + def y_fwd(): + return torch_gelu(x) + + # forward pass + if mode == "time-consuming": + convert_func = lambda ms: ms + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100) + # backward pass + if mode == "gbps": + convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6 + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100) + return convert_func(ms), convert_func(max_ms), convert_func(min_ms) + + +if 
__name__ == "__main__": + torch.manual_seed(0) + torch.set_printoptions(precision=8, linewidth=1600, sci_mode=False, edgeitems=3) + bench_load_store.run(print_data=True) diff --git a/llava/model/coat/activation/real_quantization/_quantize.py b/llava/model/coat/activation/real_quantization/_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..e1385b7dc42f25f0505a42cfecad48d9fa6eaa51 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_quantize.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Quantize Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_quantize_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + abs_output = tl.abs(output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + output = tl.fdiv(output, scale_output) + + output = output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = 
tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_quantize(x, QB, fp8type): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y = torch.empty_like(x, dtype=fp8type) + s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_quantize_kernel[grid]( + y, + s_y, + x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py new file mode 100644 index 0000000000000000000000000000000000000000..46cadbc6e77e2e34c1fafec968aa17c5621da034 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py @@ -0,0 +1,152 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""Per Tensor Quantize Operator""" +"""Input uses full precision""" +"""Output uses per tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_quantize_pertensor_kernel( + output_scale_ptr, # output + input_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + abs_output = tl.abs(output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_quantize_pertensor(x, QB, fp8type, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + fp8type = convert_str_to_fp8[fp8type] + s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_quantize_pertensor_kernel[grid]( + s_y, + x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + y, s_y_max = fp8_division(x, QB, fp8type, s_y_max, stochastic=stochastic) # reuse the floating point output y1 + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y_max, s_y diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py new file mode 100644 index 
0000000000000000000000000000000000000000..4e91d01445c99928aa7f79227cdbc7706fee1981 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py @@ -0,0 +1,155 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""Per Tensor Quantize and Transpose Operator""" +"""Input uses floating point tensor""" +"""Output uses per-tensor quantization, returns a non-transpose version and a transpose version""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_quantize_pertensor_transpose_kernel( + output_scale_ptr, # output + input_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + abs_output = tl.abs(output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_quantize_pertensor_transpose(x, QB, fp8type, transpose_output_2d=False, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS 
= x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + fp8type = convert_str_to_fp8[fp8type] + s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_quantize_pertensor_transpose_kernel[grid]( + s_y, + x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose( + x, QB, fp8type, s_y_max, stochastic=stochastic + ) # Stochastic Rounding happens here + + # Recover 2D to 3D + if batched: + qy = qy.reshape(BS, -1, qy.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, qy_t.shape[-1]) + + return qy, s_y_max, qy_t # y_t is expected to be 2D tensor diff --git a/llava/model/coat/activation/real_quantization/_transpose.py b/llava/model/coat/activation/real_quantization/_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..5b0b8ec2dc499d2bcd661e619d9f54e91daeac0c --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_transpose.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import get_configs_io_block + +"""Quantize Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.jit +def _fp8_transpose_kernel( + output_ptr, # output + input_ptr, # input + M, + N, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + + output = tl.trans(input) + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(N, M), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M), + block_shape=(BLOCK_N, BLOCK_M), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +def fp8_transpose(x, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + + y = torch.empty((N, M), dtype=x.dtype, device=x.device) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_transpose_kernel[grid]( + y, + x, + M, + N, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + ) + + # Recover 2D to 3D + if batched and not transpose_output_2d: + y = y.reshape(BS, -1, y.shape[-1]) + + return y diff --git a/llava/model/coat/activation/real_quantization/add_bwd.py b/llava/model/coat/activation/real_quantization/add_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..5951afec125048d8549fd1be0ae204df1a6b222a --- /dev/null +++ b/llava/model/coat/activation/real_quantization/add_bwd.py @@ -0,0 +1,205 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""Element-wise Add, useful for backward""" +"""Input1 (Residual) uses full-precision/BF16""" +"""Input2 (Backbone) uses full-precision/BF16""" +"""Output1 uses full-precision/BF16""" +"""Output2 uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_add_Ifp_Ifp_Ofp_Opt_kernel( + output1_ptr, # output + output2_scale_ptr, + input1_ptr, # input + input2_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input1_stride_0, + input1_stride_1, # input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + output1_stride_0, + output1_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + input1 = input1.to(tl.float32) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + input2 = input2.to(tl.float32) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Add + add_output = input1 + input2 + + # Quantize the grad 1 - Scale calculation + abs_add_output = tl.abs(add_output) + max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES + scale_output2 = max_val / fp8_max + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1)) + + # save the fp add output + fp_add_output = add_output.to(output1_ptr.type.element_ty) + fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N)) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output1_block_ptr, fp_add_output, boundary_check=(0, 1)) + + # Quantize + scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty) + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN)) + + # pointers + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1)) + + +def fp8_add_Ifp_Ifp_Ofp_Opt(x1, x2, QB, 
fp8type, stochastic=False): # suppose x1 is full precision or BF16 + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(x2.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + + # defining the input and output tensor + M, N = x1.shape + SN = N // QB + assert x1.shape == x2.shape + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y1 = torch.empty_like(x1, dtype=torch.bfloat16) + s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_add_Ifp_Ifp_Ofp_Opt_kernel[grid]( + y1, + s_y2, + x1, + x2, + M, + N, + SN, + QB, + fp8MaxValue, + x1.stride(0), + x1.stride(1), + x2.stride(0), + x2.stride(1), + y1.stride(0), + y1.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y2_max = s_y2.max() + qy2, s_y2_max = fp8_division(y1, QB, fp8type, s_y2_max, stochastic=stochastic) # reuse the floating point output y1 + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + qy2 = qy2.reshape(BS, -1, qy2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (qy2, s_y2_max, s_y2) diff --git a/llava/model/coat/activation/real_quantization/add_fwd.py b/llava/model/coat/activation/real_quantization/add_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..90aed0b3573e4c1a21b4eecea959f6f537f61be8 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/add_fwd.py @@ -0,0 +1,219 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""Element-wise Add, used in forward pass""" +"""Input1 (Residual) uses full-precision/BF16""" +"""Input2 (Backbone) uses full-precision/BF16""" +"""Output1 uses full-precision/BF16""" +"""Output2 uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_add_Ifp_Ifp_Ofp_Og16_kernel( + output1_ptr, # output + output2_ptr, + output2_scale_ptr, + input1_ptr, # input + input2_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input1_stride_0, + input1_stride_1, # input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + output1_stride_0, + output1_stride_1, # output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + input1 = input1.to(tl.float32) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + input2 = input2.to(tl.float32) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Add + add_output = input1 + input2 + + # Quantize the grad 1 - Scale calculation + abs_add_output = tl.abs(add_output) + max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES + scale_output2 = max_val / fp8_max + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1)) + + # save the fp add output + fp_add_output = add_output.to(output1_ptr.type.element_ty) + fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N)) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output1_block_ptr, fp_add_output) + + # Quantize + add_output = tl.fdiv(add_output, scale_output2) + scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty) + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN)) + add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N)) + + add_output = add_output.to(output2_ptr.type.element_ty) + add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + 
offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, add_output, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1)) + + +def fp8_add_Ifp_Ifp_Ofp_Og16(x1, x2, fp8type, QB): # suppose x1 is full precision or BF16 + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + + # defining the input and output tensor + M, N = x1.shape + SN = int(N / QB) # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + + y1 = torch.empty_like(x1, dtype=torch.bfloat16) + y2 = torch.empty_like(x2, dtype=fp8type) + s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_add_Ifp_Ifp_Ofp_Og16_kernel[grid]( + y1, + y2, + s_y2, + x1, + x2, + M, + N, + SN, + QB, + fp8MaxValue, + x1.stride(0), + x1.stride(1), + x2.stride(0), + x2.stride(1), + y1.stride(0), + y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (y2, s_y2) diff --git a/llava/model/coat/activation/real_quantization/common.py b/llava/model/coat/activation/real_quantization/common.py new file mode 100644 index 0000000000000000000000000000000000000000..00de715cf3e80eddf74ea401e1bbc9811c3e4970 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/common.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton + +SCALE_MIN_THRES = 1e-10 + +FP8_MAX_VALUE = { + torch.float8_e4m3fn: 448.0, + torch.float8_e5m2: 57344.0, +} + +convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2} +convert_fp8_to_embit = { + torch.float8_e4m3fn: (4.0, 3.0), + torch.float8_e5m2: (5.0, 2.0), +} + + +def get_configs_io_block(): + configs = [] + for nstages in [3, 4, 5]: + for block_m in [32, 64]: + for block_n in [64, 128]: + for nwarps in [4, 8]: + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n}, + num_stages=nstages, + num_warps=nwarps, + ) + ) + return configs + + +# from .common import SCALE_MIN_THRES, FP8_MAX_VALUE +# SCALE_MIN_THRES: tl.constexpr, +# + SCALE_MIN_THRES +# SCALE_MIN_THRES=SCALE_MIN_THRES, diff --git a/llava/model/coat/activation/real_quantization/fp8linear.py b/llava/model/coat/activation/real_quantization/fp8linear.py new file mode 100644 index 0000000000000000000000000000000000000000..32531fcb4bb63e4720e3fc48fb08b1f2d82d54db --- /dev/null +++ b/llava/model/coat/activation/real_quantization/fp8linear.py @@ -0,0 +1,250 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import os +import time +from copy import deepcopy +from dataclasses import dataclass + +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd.function import Function + +from ..utils import quant_get_local_rank +from ._division_transpose import fp8_division_transpose +from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose +from .linear import fp8_linear_backward, fp8_linear_forward + + +@dataclass +class DefaultArgs: + fabit: int + fwbit: int + bobit: int + + +class FP8Linear(nn.Linear): + def __init__(self, in_features, out_features, bias=True, device=None, args=None, layer_idx=0): + super().__init__(in_features, out_features, bias, device) + + if args is None: # I do not want to pass a new argument to OLMo so just use this method + args = DefaultArgs( + fabit=os.environ["FABIT_FP8Linear"], + fwbit=os.environ["FWBIT_FP8Linear"], + bobit=os.environ["BOBIT_FP8Linear"], + ) + self.args = deepcopy(args) + + if quant_get_local_rank() == 0: + print(f"[qlinear debug] Apply QLinear, {layer_idx}") + + self.layer_idx = layer_idx + self.layer_name = None + + def forward(self, Input): + if self.training: + # if False: + output = QuantLinearTE.apply(Input, self.weight, self.bias, self.args, self.layer_name) + else: + output = F.linear(Input, self.weight, self.bias) + + return output + + +# if int(os.environ.get("LOCAL_RANK")) == 0: +# import IPython +# IPython.embed() +# else: +# import time +# time.sleep(1000) + +# class QuantLinearTE(Function): +# @staticmethod +# def forward(ctx, input, weight, bias, args, layer_type): +# ctx.saved = input, weight, bias, args, layer_type +# return F.linear(input, weight, bias) + +# @staticmethod +# def backward(ctx, grad_output): +# input, weight, bias, args, layer_type = ctx.saved + +# C_in = input.shape[-1] +# C_out = grad_output.shape[-1] + +# grad_output_flatten = grad_output.reshape(-1, C_out) +# input_flatten = input.reshape(-1, C_in) + +# if grad_output_flatten.dtype == input_flatten.dtype: +# grad_weight = grad_output_flatten.t().mm(input_flatten) +# else: +# grad_weight = grad_output_flatten.float().t().mm(input_flatten) + +# if grad_output_flatten.dtype == weight.dtype: +# grad_input = grad_output_flatten.mm(weight) +# else: +# grad_input = grad_output_flatten.float().mm(weight) + +# if bias is not None: +# grad_bias = grad_output_flatten.sum(0) +# else: +# grad_bias = None + +# grad_input_transform = grad_input.reshape(input.size()) + +# return grad_input_transform, grad_weight, grad_bias, None, None + + +class QuantLinearTE(Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16) + def forward(ctx, input, weight, bias, args, layer_name): + + time_bench = os.getenv("TIME_BENCH") + + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Qinput, Iscale, Qinput_t = fp8_division_transpose(input, 16, args.fabit) + Qinput, Iscale, Qinput_t = fp8_quantize_pertensor_transpose(input, 16, args.fabit, transpose_output_2d=True) + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + # Qweight, Wscale, Qweight_t = fp8_division_transpose(weight, 16, args.fwbit) + Qweight, Wscale, Qweight_t = fp8_quantize_pertensor_transpose(weight, 16, args.fwbit, transpose_output_2d=True) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + 
start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + ctx.saved = Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name + fc_output = fp8_linear_forward(Qinput, Iscale, Qweight, Wscale, False, 0, bias) + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + output = F.linear(input, weight, bias) + + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if quant_get_local_rank() == 0: + print( + f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {input.shape} | Weight shape: {weight.shape}" + ) + + return fc_output + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved + + time_bench = os.getenv("TIME_BENCH") + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Qgrad_output, Gscale, Qgrad_output_t = fp8_division_transpose(grad_output, 16, args.bobit, stochastic=False) + Qgrad_output, Gscale, Qgrad_output_t = fp8_quantize_pertensor_transpose( + grad_output, 16, args.bobit, stochastic=False, transpose_output_2d=True + ) + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + grad_input, grad_weight = fp8_linear_backward( + Qinput_t, + Iscale, + Qgrad_output, + Gscale, + Qgrad_output_t, + Qweight_t, + Wscale, + 16, + bias, + stochastic=False, + dgrad_quantize=False, + ) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + if bias is not None: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + else: + grad_bias = None + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + + # ========== BF16 ========== + C_in = Qinput_t.shape[0] + C_out = grad_output.shape[-1] + grad_output_flatten = grad_output.reshape(-1, C_out) + input_flatten = Qinput_t.t().reshape(-1, C_in).to(torch.bfloat16) + weight = Qweight_t.t().to(torch.bfloat16) + + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + if grad_output_flatten.dtype == input_flatten.dtype: + _grad_weight = grad_output_flatten.t().mm(input_flatten) + else: + _grad_weight = grad_output_flatten.float().t().mm(input_flatten) + + if grad_output_flatten.dtype == weight.dtype: + _grad_input = grad_output_flatten.mm(weight) + else: + _grad_input = grad_output_flatten.float().mm(weight) + + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if quant_get_local_rank() == 0: + print( + f"[Backward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {Qinput_t.shape} | Weight shape: {weight.shape}" + ) + + return grad_input, grad_weight, grad_bias, None, None diff --git a/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py new file mode 100644 index 
0000000000000000000000000000000000000000..c3a1f57f3d1db8b19da687fe823a83f35d97d7d6 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py @@ -0,0 +1,303 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Layer Normalization +==================== +In this tutorial, you will write a high-performance layer normalization +kernel that runs faster than the PyTorch implementation. + +In doing so, you will learn about: + +* Implementing backward pass in Triton. + +* Implementing parallel reduction in Triton. + +""" + +import torch +import triton +import triton.language as tl + +from ._division import _stochastic_rounding +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit + +"""FP8 LayerNorm. Forward + Backward""" +"""Forward: input uses 1 * 16 quantization""" +"""Forward: output use per-tensor quantization.""" +"""Backward: input uses full-precision/BF16.""" +"""Backward: output uses full-precision/BF16""" +"""Support 3D input, but need to first flatten to 2D to perform calculation.""" + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + SX, # pointer to the scale of input + Y, # pointer to the output + SY, # pointer to the scale of output + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X, + N2: tl.constexpr, # number of columns in X, + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + eps, # epsilon to avoid division by zero + fp8_max, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
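+    # Each program instance handles one row. A row arrives group-quantized
+    # (FP8 values plus one bfloat16 scale per QB-element group), so it is first
+    # dequantized to fp32, then mean/variance and x_hat = (x - mean) * rstd are
+    # computed in fp32. Only a per-row scale, max(|x_hat|) / fp8_max, is stored
+    # here; the actual FP8 cast (and the transposed copy) happens afterwards in
+    # fp8_division_transpose using the per-tensor max of these row scales.
+    # Illustrative numbers (not from the source): with E4M3 (fp8_max = 448) and a
+    # row whose max |x_hat| is 3.2, the stored scale is roughly 3.2 / 448 ≈ 0.0071.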
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + SX += row * scale_stride + + mean = 0 + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + + # Dequantize and swish calculation + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, (N2,)) + + # Compute mean and Variance + mean = tl.sum(x, axis=0) / N + # Compute variance + _var = (x - mean) * (x - mean) + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + x_hat = (x - mean) * rstd + + # Scale calculation + abs_x_hat = tl.abs(x_hat) + max_val = tl.max(abs_x_hat, axis=0) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = scale_output.to(SY.type.element_ty) + + tl.store(SY + row, scale_output) + + # Write output + tl.store(Y + cols, x_hat, mask=cols < N) + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _layer_norm_bwd_dx_fused( + DX, # pointer to the input gradient + DY, # pointer to the output gradient + X, # pointer to the input + SX, # pointer to the input + noise_ptr, # noise for stochastic + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X + N2: tl.constexpr, # number of columns in X + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row = tl.program_id(0) + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + SX += row * scale_stride + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, N2) + + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd + xhat = tl.where(mask, xhat, 0.0) + dy = tl.where(mask, dy, 0.0) + c1 = tl.sum(xhat * dy, axis=0) / N + c2 = tl.sum(dy, axis=0) / N + dx = (dy - (xhat * c1 + c2)) * rstd + + if STOCHASTIC: + # noise_ptr += row * stride + # noise_block_ptr = noise_ptr + cols + # noise = tl.load(noise_block_ptr, mask=mask, other=0.) 
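+        # When STOCHASTIC is set, per-element uniform noise is generated inside the
+        # kernel with tl.rand, keyed by the flat element offset, instead of loading a
+        # pre-generated noise tensor (the commented-out variant above). The noise is
+        # consumed by _stochastic_rounding, which rounds dx up or down at random in
+        # proportion to its distance from the neighbouring representable values.
+        # Here (e_bit, m_bit) are assumed to be the exponent/mantissa widths of the
+        # target FP8 format (e.g. (4, 3) for E4M3, cf. convert_fp8_to_embit), supplied
+        # to the kernel when this path is enabled.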
+ + noise_offset = row * stride + cols + noise = tl.rand(0, noise_offset) + + dx = _stochastic_rounding(dx, noise, e_bit, m_bit) + + dx = dx.to(DX.type.element_ty) + + # Write dx + tl.store(DX + cols, dx, mask=mask) + + +def fp8_layernorm_noparam_forward(x, s_x, QB, eps, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # allocate output + M, N = x.shape + _, SN = s_x.shape + y = torch.empty_like(x, dtype=torch.float32) + s_y = torch.empty( + (M,), dtype=torch.bfloat16, device=x.device + ) # We do this because we apply per-tensor quantization for it afterwards + mean = torch.empty((M,), dtype=torch.float32, device=x.device) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + # heuristics for number of warps + num_warps = 8 + fp8MaxValue = FP8_MAX_VALUE[x.dtype] + + N2 = triton.next_power_of_2(N) + SN2 = N2 // QB + # enqueue kernel + _layer_norm_fwd_fused[(M,)]( # + x, + s_x, + y, + s_y, + mean, + rstd, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + eps, + fp8MaxValue, + SCALE_MIN_THRES=SCALE_MIN_THRES, # + num_warps=num_warps, + num_ctas=1, + ) + # reduction + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, y.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, y.shape[-1]) + + return qy, s_y_max, qy_t, (mean, rstd, num_warps) + + +def fp8_layernorm_noparam_backward(x, s_x, g, QB, m, v, num_warps, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + if stochastic: + # noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + noise = None + else: + noise = None + + # heuristics for amount of parallel reduction stream for DW/DB + dx = torch.empty_like(g, dtype=torch.bfloat16) + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + M, N = g.shape + _, SN = s_x.shape + + N2 = triton.next_power_of_2(N) + SN2 = triton.next_power_of_2(SN) + _layer_norm_bwd_dx_fused[(M,)]( # + dx, + g, + x, + s_x, + noise, + m, + v, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + num_warps=num_warps, + ) + + if batched: + dx = dx.reshape(BS, -1, dx.shape[-1]) + + return dx diff --git a/llava/model/coat/activation/real_quantization/func_quantize.py b/llava/model/coat/activation/real_quantization/func_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..f63942146890a02132b431a10df39f22a65ebec4 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/func_quantize.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from copy import deepcopy + +import torch +import torch.nn as nn + +from ._quantize import fp8_quantize +from ._quantize_pertensor import fp8_quantize_pertensor + + +class Coat_quantize_bgn(nn.Module): + def __init__(self, args=None, layer_type=""): + super().__init__() + self.args = deepcopy(args) + self.fp8type = self.args.fabit + self.layer_type = layer_type + + def forward(self, input): + if self.training: + return Coat_quantize_bgn_func.apply(input, self.args.group_size, self.fp8type) + else: + return input, None, None + + +class Coat_quantize_bgn_func(torch.autograd.Function): + @staticmethod + def forward(ctx, input, group_size, fp8type): + """ + (Qoutput, Oscale) uses 1 * 16 quantization + """ + Qoutput, Oscale = fp8_quantize(input, group_size, fp8type) + # For autograd + Qoutput = Qoutput.view(torch.float8_e4m3fn) + ctx.saved = group_size + return input, Qoutput, Oscale + + @staticmethod + def backward(ctx, grad_output, Qgrad_output, Gscale): + """ + (Qgrad_output, Gscale) uses 1 * 16 quantization + """ + return grad_output, None, None + + +class Coat_quantize_end(nn.Module): + def __init__(self, args=None, layer_type=""): + super().__init__() + self.args = deepcopy(args) + self.fp8type = self.args.babit + self.layer_type = layer_type + + def forward(self, input, Qinput, Iscale): + if self.training: + return Coat_quantize_end_func.apply(input, Qinput, Iscale, self.args.group_size, self.fp8type) + else: + return input + + +class Coat_quantize_end_func(torch.autograd.Function): + @staticmethod + def forward(ctx, input, Qinput, Iscale, group_size, fp8type): + """ + (Qinput, Iscale) uses 1 * 16 quantization + """ + ctx.saved = group_size, fp8type + + return input + + @staticmethod + def backward(ctx, grad_output): + """ + (Qgrad_output, Gscale) uses per-tensor quantization + """ + + group_size, fp8type = ctx.saved + Qgrad_output, Gscale, Gscale_g16 = fp8_quantize_pertensor(grad_output, group_size, fp8type, stochastic=False) + + # For autograd + Qgrad_output = Qgrad_output.view(torch.float8_e4m3fn) + + return grad_output, Qgrad_output, Gscale_g16, None, None diff --git a/llava/model/coat/activation/real_quantization/func_rmsnorm.py b/llava/model/coat/activation/real_quantization/func_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a8d38e2909d7a45a7792e9b1502af5aa7dee95 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/func_rmsnorm.py @@ -0,0 +1,345 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Layer Normalization +==================== +In this tutorial, you will write a high-performance layer normalization +kernel that runs faster than the PyTorch implementation. + +In doing so, you will learn about: + +* Implementing backward pass in Triton. + +* Implementing parallel reduction in Triton. + +""" + +import torch +import triton +import triton.language as tl + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES + +"""RMSNorm Forward + Backward""" +"""Forward: Input uses 1 * 16 group quantization""" +"""Forward: Output uses per-tensor quantization""" +"""Backward: Input uses full-precision/BF16.""" +"""Backward: Output uses full-precision/BF16.""" + +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _rms_norm_fwd_fused( + X, # pointer to the input + SX, # pointer to the scale of input + Y, # pointer to the output + SY, # pointer to the scale of output + W, # Weight + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X, + N2: tl.constexpr, # number of columns in X, + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + eps, # epsilon to avoid division by zero + fp8_max, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
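+    # Each program instance normalizes one row. The FP8 input is dequantized with its
+    # 1 x QB group scales, then RMSNorm is applied in fp32:
+    #   rstd = 1 / sqrt(mean(x * x) + eps),   y = (x * rstd) * w
+    # Only the per-row scale max(|y|) / fp8_max is stored here; the FP8 cast and the
+    # transposed copy are produced later by fp8_division_transpose from the
+    # per-tensor max of these row scales.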
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + SX += row * scale_stride + + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + + # Dequantize and swish calculation + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, N2) + + # Compute variance + _var = x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Write mean / rstd + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + w = tl.load(W + cols, mask=cols < N, other=0.0) + x_hat = x * rstd + y = x_hat * w + + # Scale calculation + abs_y = tl.abs(y) + max_val = tl.max(abs_y) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + tl.store(SY + row, scale_output) + + # Write output + tl.store(Y + cols, y, mask=cols < N) + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _rms_norm_bwd_dx_fused( + DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, + X, # pointer to the input + SX, # pointer to the input + W, # weight + Rstd, # pointer to the 1/std + Lock, + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X, + N2: tl.constexpr, # number of columns in X, + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
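+    # Per-row RMSNorm backward. With xhat = x * rstd and wdy = w * dy:
+    #   c1 = mean(xhat * wdy),   dx = (wdy - xhat * c1) * rstd
+    # (no c2 mean-subtraction term, unlike LayerNorm). Partial weight gradients
+    # dy * xhat are accumulated into one of GROUP_SIZE_M shared buffers, selected by
+    # row % GROUP_SIZE_M and guarded by a spin lock (tl.atomic_cas), and later reduced
+    # to the final dw by _rms_norm_bwd_dwdb.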
+ row = tl.program_id(0) + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + SX += row * scale_stride + + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_SIZE_M + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, N2) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + # Load weight + w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row).to(tl.float32) + + # Compute dx + xhat = x * rstd + wdy = w * dy + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - (xhat * c1)) * rstd # layer norm have c2 term, rmsnorm do not + + dx = dx.to(DX.type.element_ty) + + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Accumulate partial sums for dw/db + partial_dw = (dy * xhat).to(w.dtype) + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + # First store doesn't accumulate + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + tl.store(DW, partial_dw, mask=mask) + # Release the lock + tl.atomic_xchg(Lock, 0) + + +@triton.jit +def _rms_norm_bwd_dwdb( + DW, # pointer to the partial sum of weights gradient + FINAL_DW, # pointer to the weights gradient + M, # GROUP_SIZE_M + N, # number of columns + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + # Map the program id to the elements of DW and DB it should compute. + pid = tl.program_id(0) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Iterate through the rows of DW and DB to sum the partial sums. + for i in range(0, M, BLOCK_SIZE_M): + rows = i + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + # Write the final sum to the output. 
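+    # Column-wise reduction of the (GROUP_SIZE_M x N) partial-sum buffer: each program
+    # owns BLOCK_SIZE_N columns, walks over the partial rows in tiles of BLOCK_SIZE_M,
+    # accumulates them in fp32, and writes the summed weight gradient to FINAL_DW.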
+ sum_dw = tl.sum(dw, axis=0) + tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) + + +def fp8_rmsnorm_forward(x, s_x, w, QB, eps, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # allocate output + M, N = x.shape + _, SN = s_x.shape + y = torch.empty_like(x, dtype=torch.bfloat16) + s_y = torch.empty((M,), dtype=torch.bfloat16, device=x.device) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + # heuristics for number of warps + num_warps = 8 + fp8MaxValue = FP8_MAX_VALUE[x.dtype] + + N2 = triton.next_power_of_2(N) + SN2 = N2 // QB + + # import os + # if int(os.environ.get("LOCAL_RANK")) == 7: + # print(x.device, x.shape, x.dtype, s_x.shape, s_x.dtype, y.shape, y.dtype, s_y.shape, s_y.dtype, w.shape, w.dtype, rstd.shape, rstd.dtype, x.stride(0), s_x.stride(0), N, N2, SN, SN2, QB, "\n") + # enqueue kernel + _rms_norm_fwd_fused[(M,)]( # + x, + s_x, + y, + s_y, + w, + rstd, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + eps, + fp8MaxValue, + SCALE_MIN_THRES=SCALE_MIN_THRES, # + num_warps=num_warps, + num_ctas=1, + ) + + # reduction + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, y.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, y.shape[-1]) + + return qy, s_y_max, qy_t, (w.clone(), rstd, num_warps) + + +def fp8_rmsnorm_backward(x, s_x, g, w, v, QB, num_warps): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + M, N = g.shape + _, SN = s_x.shape + + GROUP_SIZE_M = 128 + # heuristics for amount of parallel reduction stream for DW/DB + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device) + _dw = torch.zeros((GROUP_SIZE_M, N), dtype=w.dtype, device=w.device) + dw = torch.empty((N,), dtype=w.dtype, device=w.device) + + dx = torch.empty_like(g, dtype=torch.bfloat16) + + N2 = triton.next_power_of_2(N) + SN2 = triton.next_power_of_2(SN) + _rms_norm_bwd_dx_fused[(M,)]( # + dx, + g, + _dw, + x, + s_x, + w, + v, + locks, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + SCALE_MIN_THRES=SCALE_MIN_THRES, + num_warps=num_warps, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + + if batched: + dx = dx.reshape(BS, -1, dx.shape[-1]) + + grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])] + # accumulate partial sums in separate kernel + _rms_norm_bwd_dwdb[grid](_dw, dw, min(GROUP_SIZE_M, M), N, BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_ctas=1) # # + + return dx, dw diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd.py b/llava/model/coat/activation/real_quantization/gelu_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..8772707f7f5e6ef9b68b9c29f2b1b329f2c3177d --- /dev/null +++ b/llava/model/coat/activation/real_quantization/gelu_bwd.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""GELU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses full-precision/BF16""" +"""Output uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_gelu_backward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of GELU's backward + pi = float(torch.pi) + cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2 + exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi) + gelu_output = cdf + exp + + gelu_output = gelu_output * grad + + # Quantize Scale calculation + 
abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # gelu_output = tl.fdiv(gelu_output, scale_output) + gelu_output = gelu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_gelu_backward(x, s_x, g, QB, fp8type): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_gelu_backward_kernel[grid]( + y, + s_y, + x, + s_x, + g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Per-tensor quantization + s_y_max = s_y.max() + qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return qy, s_y_max diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2b36bcb8c51d33392cd740c3dcbf6de4c9f082 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py @@ -0,0 +1,257 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""GELU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses 1 * 16 group quantization""" +"""Output uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_gelu_backward_legacy_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, + grad_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + s_grad_stride_0, + s_grad_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # grad ptr + scale_grad_ptr = tl.make_block_ptr( + base=grad_scale_ptr, + shape=(M, SN), + strides=(s_grad_stride_0, s_grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + scale_grad = tl.load(scale_grad_ptr) + + grad = grad.to(tl.float32) + scale_grad = scale_grad.to(tl.float32) + + # Dequantize and gelu calculation + scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1)) + grad = tl.reshape(grad, 
(BLOCK_M, BLOCK_SN, QB)) + grad = grad * scale_grad + + # Actual Calculation of GELU's backward + pi = float(torch.pi) + cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2 + exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi) + gelu_output = cdf + exp + + gelu_output = gelu_output * grad + + # Quantize Scale calculation + abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # gelu_output = tl.fdiv(gelu_output, scale_output) + gelu_output = gelu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_gelu_backward_legacy(x, s_x, g, s_g, QB): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + s_g = s_g.reshape(-1, s_g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_g, dtype=s_g.dtype) + fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_gelu_backward_legacy_kernel[grid]( + y, + s_y, + x, + s_x, + g, + s_g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + s_g.stride(0), + s_g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Per-tensor quantization + s_y_max = s_y.max() + qy, s_y_max = fp8_division(y, QB, g.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return qy, s_y_max diff --git a/llava/model/coat/activation/real_quantization/gelu_fwd.py b/llava/model/coat/activation/real_quantization/gelu_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..547c6f8339feef79870ade12a407c852631ae136 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/gelu_fwd.py @@ -0,0 +1,209 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
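# --- Reference sketch (not part of this diff): the shared 1 x QB group scheme
# Every kernel here dequantizes by reshaping (M, N) -> (M, N // QB, QB) and
# multiplying by a per-group scale, then re-derives a fresh scale from the
# result as amax / fp8_max (with a small SCALE_MIN_THRES floor). A PyTorch
# sketch of those two halves (helper names are ours, not a repo API):
import torch

def group_dequantize(q: torch.Tensor, s: torch.Tensor, QB: int) -> torch.Tensor:
    M, N = q.shape
    x = q.float().reshape(M, N // QB, QB) * s.float().unsqueeze(-1)
    return x.reshape(M, N)

def group_scale(y: torch.Tensor, QB: int, fp8_max: float, scale_min_thres: float) -> torch.Tensor:
    # scale_min_thres mirrors SCALE_MIN_THRES from common.py (value not shown here)
    M, N = y.shape
    max_val = y.reshape(M, N // QB, QB).abs().amax(dim=-1) + scale_min_thres
    return max_val / fp8_max  # one scale per 1 x QB group, shape (M, N // QB)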
+ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""GELU Activation Forward""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + +__all__ = ["fp8_gelu_forward"] + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_gelu_forward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # Actual Calculation of GeLU + cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2 + gelu_output = cdf * input + + # Quantize Scale calculation + abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # gelu_output = tl.fdiv(gelu_output, scale_output) + gelu_output = gelu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = 
scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_gelu_forward(x, s_x, QB, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(x, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x, dtype=s_x.dtype) + fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_gelu_forward_kernel[grid]( + y, + s_y, + x, + s_x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, qy.shape[-1]) + + return qy, s_y_max, qy_t diff --git a/llava/model/coat/activation/real_quantization/linear.py b/llava/model/coat/activation/real_quantization/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..d643f18f51432efc3a2f80c55b73a18e99f0eb75 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/linear.py @@ -0,0 +1,307 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
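# --- Hypothetical usage sketch (not part of this diff) for the fp8_gelu_forward
# wrapper defined just above. Assumes a CUDA device, torch >= 2.1
# (float8_e4m3fn), QB = 16 as in the "1 * 16 group quantization" docstrings,
# and that this package is importable; quantize_per_group is an illustrative
# helper, not a repo API.
import torch
# from llava.model.coat.activation.real_quantization.gelu_fwd import fp8_gelu_forward

def quantize_per_group(x_hp: torch.Tensor, QB: int = 16, fp8_max: float = 448.0):
    BS, M, N = x_hp.shape
    groups = x_hp.float().reshape(BS, M, N // QB, QB)
    s = groups.abs().amax(dim=-1) / fp8_max + 1e-12  # small floor, like SCALE_MIN_THRES
    q = (groups / s.unsqueeze(-1)).reshape(BS, M, N)
    return q.to(torch.float8_e4m3fn), s.to(torch.bfloat16)

# x_hp = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)
# q, s = quantize_per_group(x_hp)
# qy, s_y_max, qy_t = fp8_gelu_forward(q, s, QB=16)  # FP8 output, per-tensor scale, transposed copy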
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +try: + from ._division import _stochastic_rounding + from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8 +except: + from common import SCALE_MIN_THRES, FP8_MAX_VALUE, convert_str_to_fp8, convert_fp8_to_embit + from COAT.coat.activation.real_quantization._division import _stochastic_rounding + +import os +import time + +"""Linear Layer Forward + Backward""" +"""Input uses per-tensor quantization""" +"""Output is full-precision/BF16 (for FlashAttention) or 1 * 16 quantization (for the rest)""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +def get_configs_io_block(): + configs = [] + for nstages in [3]: + for block_m in [128, 256]: + for block_n in [128, 256]: + for block_k in [128, 256]: + for nwarps in [8]: + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k}, + num_stages=nstages, + num_warps=nwarps, + ) + ) + return configs + + +# @triton.autotune( +# configs=get_configs_io_block(), +# key=["M", "N", "K"], +# ) +@triton.jit +def _fp8matmul_kernel( + A, + B, + C, + noise_ptr, # noise for stochastic + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, ## + Scale_A, + Scale_B, + Scale_C, + stride_scm, + stride_scn, + output_quantize: tl.constexpr, + QB: tl.constexpr, # default to use 1 * 16 quantization + BIAS, + fp8_max: tl.constexpr, + e_bit: tl.constexpr, + m_bit: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + # matrix multiplication + pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) + + acc = tl.dot(a, b, acc) + + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + scale_a = tl.load(Scale_A) + scale_b = tl.load(Scale_B) + scale_ab = scale_a.to(tl.float32) * scale_b.to(tl.float32) + # fp8 dequantize + acc = acc * scale_ab + + if BIAS: + bias = tl.load(BIAS + rbn) + acc = acc + bias + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + + if output_quantize: + acc = tl.reshape(acc, (BLOCK_M, BLOCK_N // QB, QB)) + abs_acc = tl.abs(acc) + acc_max = 
tl.max(abs_acc, axis=2) + SCALE_MIN_THRES + # tl.device_print("acc_max", acc_max) + acc_scale = acc_max / fp8_max + # tl.device_print("acc_scale", acc_scale) + acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB, 1)) + acc = tl.fdiv(acc, acc_scale) + acc = tl.reshape(acc, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + noise_block_ptr = noise_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + noise = tl.load(noise_block_ptr, boundary_check=(0, 1)) + acc = _stochastic_rounding(acc, noise, e_bit, m_bit) + + acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB)) + acc_scale = acc_scale.to(Scale_C.type.element_ty) + acc = acc.to(C.dtype.element_ty) + + rsm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rsn = pid_n * BLOCK_N // QB + tl.arange(0, BLOCK_N // QB) + Scale_C = Scale_C + (rsm[:, None] * stride_scm + rsn[None, :] * stride_scn) + + tl.store(C, acc, mask=mask) + tl.store(Scale_C, acc_scale) + + else: + # handles write-back with reduction-splitting + acc = acc.to(C.dtype.element_ty) + tl.store(C, acc, mask=mask) + + +def fp8matmul(a, b, output_quantize, scale_a, scale_b, QB, bias=None, stochastic=False): + # Deal with batched input + if len(a.shape) == 3: + BS, batched = a.shape[0], True + a = a.reshape(-1, a.shape[2]) + else: + batched = False + + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + fp8MaxValue = FP8_MAX_VALUE[a.dtype] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[a.dtype] + + # Allocates output. + if output_quantize: + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # c = torch.empty((M, N), device=a.device, dtype=torch.float32) + scale_c = torch.empty((M, N // QB), device=a.device, dtype=torch.bfloat16) + else: + c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) + scale_c = torch.empty( + (1, 1), device=a.device, dtype=torch.bfloat16 + ) # This line is useless, equivalent to scale_c = None + + if stochastic: + noise = torch.empty_like(c, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + _fp8matmul_kernel[grid]( + a, + b, + c, + noise, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + scale_a, + scale_b, + scale_c, + scale_c.stride(0), + scale_c.stride(1), + output_quantize=output_quantize, + QB=QB, + BIAS=bias, + fp8_max=fp8MaxValue, + e_bit=e_bit, + m_bit=m_bit, + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + BLOCK_M=128, + BLOCK_N=256, + BLOCK_K=128, + GROUP_M=8, + num_stages=3, + num_warps=8, + ) + # Reshape output to batch + if batched: + c = c.reshape(BS, -1, N) + if output_quantize: + scale_c = scale_c.reshape(BS, -1, N // QB) + return c, scale_c + else: + if output_quantize: + scale_c = scale_c.reshape(M, N // QB) + return c, scale_c + return c + + +def fp8_linear_forward(x, s, w, s_w, output_quantize, QB, bias=None): + assert s.numel() == 1, f"X uses per-tensor quantization in linear forward, but the scale shape is {s.shape}" + assert s_w.numel() == 1, f"W uses per-tensor quantization in linear forward, but the scale shape is {s_w.shape}" + + w_t = w.t() + return fp8matmul(x, w_t, output_quantize, s, s_w, QB, bias) + + +# def fp8_linear_forward(x, s, w, s_w, output_quantize, QB): +# print("you are using the wrong linear function. 
") +# w_t = w.t() +# if output_quantize: +# return fp8matmul(x, w_t, True, s, s_w, QB) +# else: +# y = fp8matmul(x, w_t, False, s, s_w, QB) + +# return y + + +def fp8_linear_backward( + x_t, s, g, s_g, g_t, w_t, s_w, QB, bias=None, stochastic=False, dgrad_quantize=False +): # dgrad_quantize=True for backward before flashattention + assert s.numel() == 1, f"X uses per-tensor quantization in linear backward, but the scale shape is {s.shape}" + assert s_g.numel() == 1, f"G uses per-tensor quantization in linear backward, but the scale shape is {s.shape}" + assert s_w.numel() == 1, f"W uses per-tensor quantization in linear backward, but the scale shape is {s_w.shape}" + + batched = False + if len(g.shape) == 3: # others must be of 2D! + batched = True + BS = g.shape[0] + g = g.reshape(-1, g.shape[-1]) + + w_t_t = w_t.t() + x_t_t = x_t.t() + if dgrad_quantize: + y, s_y = fp8matmul(g, w_t_t, True, s_g, s_w, QB, stochastic=stochastic) + else: + y = fp8matmul(g, w_t_t, False, s_g, s_w, QB) + + w_g = fp8matmul(g_t, x_t_t, False, s_g, s, QB) + + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + if dgrad_quantize: + if s_y.numel() > 1: + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + if dgrad_quantize: + return y, s_y, w_g + else: + return y, w_g diff --git a/llava/model/coat/activation/real_quantization/mul_bwd.py b/llava/model/coat/activation/real_quantization/mul_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..27f8b2839495d6d4f92fa35d537ea75e3e04c0c2 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_bwd.py @@ -0,0 +1,320 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Element-wise Multiplication Backward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Grad (Down) uses full-precision/BF16""" +"""Output1 (Gate) uses full-precision/BF16""" +"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_mul_backward_kernel( + output1_ptr, # output + output2_ptr, + output2_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + grad_stride_0, + grad_stride_1, # input stride + output1_stride_0, + output1_stride_1, # output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # 
Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Mul Backward + grad1 = grad * input2 + grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N)) + grad1 = grad1.to(output1_ptr.type.element_ty) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + tl.store(output1_block_ptr, grad1, boundary_check=(0, 1)) + + # Actual Calculation of Mul Backward + grad2 = grad * input1 + # Quantize the grad 1 - Scale calculation + abs_grad2 = tl.abs(grad2) + max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES + scale_grad2 = max_val / fp8_max + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow + grad2 = grad2.to(output2_ptr.type.element_ty) + scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty) + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN)) + grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, grad2, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1)) + + +def fp8_mul_backward( + x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + if stochastic: + noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y1 = torch.empty_like(g, dtype=torch.bfloat16) + y2 = torch.empty_like(g, dtype=torch.bfloat16) + s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_mul_backward_kernel[grid]( + 
y1, + y2, + s_y2, + x1, + s_x1, + x2, + s_x2, + g, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + g.stride(0), + g.stride(1), + y1.stride(0), + y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + if not output_quantized_transpose: + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (y2, s_y2) + else: + # Per-tensor quantization + s_y2_max = s_y2.max() + qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + qy2 = qy2.reshape(BS, -1, qy2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (qy2, s_y2_max, qy2_t) diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..803789e83cb27330d6195a515805440af45cc528 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py @@ -0,0 +1,374 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import _stochastic_rounding +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, get_configs_io_block + +"""Element-wise Multiplication Backward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Grad (Down) uses 1 * 16 group quantization""" +"""Output1 (Gate) uses 1 * 16 quantization""" +"""Output2 (Up) uses per-tensor quantization, but should be quantized outside this function""" # Although it is per-tensor quantization, we only apply per-group quantization here, and the reduction should be performed outside this function. 
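# --- Reference sketch (not part of this diff): product-rule backward ---------
# fp8_mul_backward implements, for y = x1 * x2 with upstream gradient g,
#   dL/dx1 = g * x2  (the gate branch, kept in BF16)
#   dL/dx2 = g * x1  (re-quantized per 1 x QB group, or per-tensor on request)
# which in plain PyTorch is simply:
import torch

def mul_backward_reference(x1: torch.Tensor, x2: torch.Tensor, g: torch.Tensor):
    return g * x2, g * x1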
+"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_mul_backward_legacy_kernel( + output1_ptr, + output1_scale_ptr, # output + output2_ptr, + output2_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + grad_ptr, + grad_scale_ptr, # input + noise_ptr, # noise for stochastic + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + grad_stride_0, + grad_stride_1, # input stride + s_grad_stride_0, + s_grad_stride_1, # scale of input stride + output1_stride_0, + output1_stride_1, # output stride + s_output1_stride_0, + s_output1_stride_1, # scale of output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # grad ptr + scale_grad_ptr = tl.make_block_ptr( 
+ base=grad_scale_ptr, + shape=(M, SN), + strides=(s_grad_stride_0, s_grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + scale_grad = tl.load(scale_grad_ptr) + + grad = grad.to(tl.float32) + scale_grad = scale_grad.to(tl.float32) + + # Dequantize and swish calculation + scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1)) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + grad = grad * scale_grad + + # Actual Calculation of Mul Backward + grad1 = grad * input2 + # Quantize the grad 1 - Scale calculation + abs_grad1 = tl.abs(grad1) + max_val = tl.max(abs_grad1, axis=2) + SCALE_MIN_THRES + scale_grad1 = max_val / fp8_max + scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + grad1 = tl.fdiv(grad1, scale_grad1) # do not quantize the output due to the data flow + scale_grad1 = scale_grad1.to(output1_scale_ptr.type.element_ty) + scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN)) + grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + # noise_block_ptr = tl.make_block_ptr( + # base=noise_ptr, + # shape=(M, N), + # strides=(input1_stride_0, input1_stride_1), + # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + # block_shape=(BLOCK_M, BLOCK_N), + # order=(1, 0) + # ) + # noise = tl.load(noise_block_ptr) + + offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N) + noise_offset = offs_m[:, None] * input1_stride_0 + offs_n[None, :] * input1_stride_1 + noise = tl.rand(0, noise_offset) + + grad1 = _stochastic_rounding(grad1, noise, e_bit, m_bit) + + grad1 = grad1.to(output1_ptr.type.element_ty) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output1_ptr = tl.make_block_ptr( + base=output1_scale_ptr, + shape=(M, SN), + strides=(s_output1_stride_0, s_output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output1_block_ptr, grad1, boundary_check=(0, 1)) + tl.store(scale_output1_ptr, scale_grad1, boundary_check=(0, 1)) + + # Actual Calculation of Mul Backward + grad2 = grad * input1 + # Quantize the grad 1 - Scale calculation + abs_grad2 = tl.abs(grad2) + max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES + scale_grad2 = max_val / fp8_max + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow + grad2 = grad2.to(output2_ptr.type.element_ty) + scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty) + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN)) + grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, grad2, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_grad2, 
boundary_check=(0, 1)) + + +def fp8_mul_backward_legacy( + x1, s_x1, x2, s_x2, g, s_g, QB, stochastic=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + s_g = s_g.reshape(-1, s_g.shape[-1]) + + if stochastic: + noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + y1 = torch.empty_like(g, dtype=g.dtype) + s_y1 = torch.empty_like(s_g, dtype=s_g.dtype) + y2 = torch.empty_like(g, dtype=torch.bfloat16) + s_y2 = torch.empty_like(s_g, dtype=s_g.dtype) + fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[g.dtype] + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_mul_backward_legacy_kernel[grid]( + y1, + s_y1, + y2, + s_y2, + x1, + s_x1, + x2, + s_x2, + g, + s_g, + noise, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + g.stride(0), + g.stride(1), + s_g.stride(0), + s_g.stride(1), + y1.stride(0), + y1.stride(1), + s_y1.stride(0), + s_y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + s_y1 = s_y1.reshape(BS, -1, s_y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, s_y1, y2, s_y2 diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d23c26f6c86284b8d614ae0f90819bc2e4bc1718 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
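# --- Illustrative sketch (not part of this diff): stochastic rounding --------
# The legacy kernel above routes grad1 through _stochastic_rounding together
# with per-element noise. The actual helper targets the FP8 (e_bit, m_bit)
# grid defined in _division.py; the sketch below only shows the idea for a
# fixed step size: rounding is randomized so that it is unbiased in expectation.
import torch

def stochastic_round_to_step(x: torch.Tensor, step: float) -> torch.Tensor:
    noise = torch.rand_like(x)                   # uniform in [0, 1)
    return torch.floor(x / step + noise) * step  # E[output] == x

# torch.manual_seed(0)
# halfway = torch.full((100000,), 0.5)
# print(stochastic_round_to_step(halfway, 1.0).mean())  # ~0.5: half round up, half down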
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Element-wise Multiplication Backward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Grad (Down) uses full-precision/BF16""" +"""Output1 (Gate) uses full-precision/BF16""" +"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_mul_backward_silu_forward_kernel( + output1_ptr, # output + output2_ptr, + output2_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + grad_stride_0, + grad_stride_1, # input stride + output1_stride_0, + output1_stride_1, # output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- recompute SiLU --- + # Actual Calculation of SiLU + sigmoid = 1 / (1.0 + libdevice.exp(-input1)) + silu_output = sigmoid * input1 + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(input1_ptr.type.element_ty) + + input1 = silu_output * scale_output + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, 
input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Mul Backward + grad1 = grad * input2 + grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N)) + grad1 = grad1.to(output1_ptr.type.element_ty) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + tl.store(output1_block_ptr, grad1, boundary_check=(0, 1)) + + # Actual Calculation of Mul Backward + grad2 = grad * input1 + # Quantize the grad 1 - Scale calculation + abs_grad2 = tl.abs(grad2) + max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES + scale_grad2 = max_val / fp8_max + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow + grad2 = grad2.to(output2_ptr.type.element_ty) + scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty) + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN)) + grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, grad2, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1)) + + +def fp8_mul_backward_silu_forward( + x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + if stochastic: + noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size 
is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y1 = torch.empty_like(g, dtype=torch.bfloat16) + y2 = torch.empty_like(g, dtype=torch.bfloat16) + s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_mul_backward_silu_forward_kernel[grid]( + y1, + y2, + s_y2, + x1, + s_x1, + x2, + s_x2, + g, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + g.stride(0), + g.stride(1), + y1.stride(0), + y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + if not output_quantized_transpose: + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (y2, s_y2) + else: + # Per-tensor quantization + s_y2_max = s_y2.max() + qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + qy2 = qy2.reshape(BS, -1, qy2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (qy2, s_y2_max, qy2_t) diff --git a/llava/model/coat/activation/real_quantization/mul_fwd.py b/llava/model/coat/activation/real_quantization/mul_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..0037c7220c6f1bb0b7b9b5946a7d2553b44c7426 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_fwd.py @@ -0,0 +1,260 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
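# --- Reference sketch (not part of this diff): fused SiLU recompute + mul bwd
# The fused kernel above recomputes the gate activation silu(x1) on the fly
# (including an FP8 fake-quantization round trip that is omitted here) instead
# of storing it, then forms both partial gradients of silu(x1) * x2 from the
# down-projection gradient g. Plain-PyTorch equivalent of the math:
import torch
import torch.nn.functional as F

def mul_backward_silu_forward_reference(x1: torch.Tensor, x2: torch.Tensor, g: torch.Tensor):
    gate = F.silu(x1.float())                                  # recomputed forward activation
    d_gate_out = (g.float() * x2.float()).to(torch.bfloat16)   # grad w.r.t. silu(x1), BF16
    d_up = (g.float() * gate).to(torch.bfloat16)               # grad w.r.t. x2, re-quantized in-kernel
    return d_gate_out, d_up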
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""Element-wise Multiplication Forward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Output uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + +fp8_max_value = { + torch.float8_e4m3fn: 448, + torch.float8_e5m2: 57344, +} + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def fp8_mul_forward_kernel( + output_ptr, + output_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # Actual Calculation of SiLU + mul_output = input1 * input2 + + # Quantize Scale calculation + abs_output = tl.abs(mul_output) + max_val = 
tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # mul_output = tl.fdiv(mul_output, scale_output) # do not quantize the output since it should use per-tensor quantization afterwards + mul_output = mul_output.to(output_ptr.type.element_ty) + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + mul_output = tl.reshape(mul_output, (BLOCK_M, BLOCK_N)) + + # debug + # mul_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, mul_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_mul_forward(x1, s_x1, x2, s_x2, QB, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + y = torch.empty_like(x1, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x1, dtype=s_x1.dtype) + fp8MaxValue = fp8_max_value[x1.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + fp8_mul_forward_kernel[grid]( + y, + s_y, + x1, + s_x1, + x2, + s_x2, + M, + N, + SN, + QB, + fp8MaxValue, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x2.dtype, s_y_max) + qy = qy.to(x2.dtype) + qy_t = qy_t.to(x2.dtype) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, qy.shape[-1]) + + return qy, s_y_max, qy_t diff --git a/llava/model/coat/activation/real_quantization/silu_bwd.py b/llava/model/coat/activation/real_quantization/silu_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..052bf66e6af6f5f3ce40e08e7b5c8c92e97bfbe3 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/silu_bwd.py @@ -0,0 +1,255 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
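# --- Reference sketch (not part of this diff): SiLU derivative ---------------
# The SiLU backward kernel below computes, per element,
#   d/dx silu(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
# and multiplies by the upstream gradient. Plain-PyTorch check:
import torch

def silu_grad_reference(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    s = torch.sigmoid(x)
    return g * (s + x * s * (1.0 - s))

# x = torch.randn(8, dtype=torch.float64, requires_grad=True)
# torch.nn.functional.silu(x).backward(torch.ones_like(x))
# assert torch.allclose(x.grad, silu_grad_reference(x.detach(), torch.ones_like(x)))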
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""SiLU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses full-precision / BF16""" +"""Output uses per-tensor quantization, we can choose whether it should be quantized inside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_silu_backward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and silu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of SiLU's backward + sigmoid = 1 / (1.0 + libdevice.exp(-input)) + silu_output = sigmoid + input * sigmoid * (1 - sigmoid) + silu_output = silu_output * grad + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, 
BLOCK_SN, 1)) + + # Quantize + # silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N)) + + # debug + # silu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, silu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_silu_backward( + x, s_x, g, QB, fp8type, stochastic=False, output_quantized_transpose=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_silu_backward_kernel[grid]( + y, + s_y, + x, + s_x, + g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + if not output_quantized_transpose: + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y + else: + # Per-tensor quantization + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, fp8type, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return qy, s_y_max, qy_t + + # # Per-tensor quantization + # s_y_max = s_y.max() + # qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max) + + # # Recover 2D to 3D + # if batched: + # y = y.reshape(BS, -1, y.shape[-1]) + # qy = qy.reshape(BS, -1, qy.shape[-1]) + # s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + # return qy, s_y_max diff --git a/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..88de423ad973f34718f8d44e1d2e37b3113e1737 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. 
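# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of this patch): the closed-form
# derivative used by the SiLU backward kernel above and the legacy kernel below is
#     d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)),
# applied to the dequantized input and then multiplied by the incoming gradient.
# A quick autograd cross-check of that formula:
import torch
import torch.nn.functional as F

x = torch.randn(4, 32, dtype=torch.float64, requires_grad=True)
g = torch.randn_like(x)
F.silu(x).backward(g)                        # reference gradient from autograd
s = torch.sigmoid(x.detach())
manual = (s + x.detach() * s * (1 - s)) * g  # formula used by the kernels
assert torch.allclose(x.grad, manual)
# ---------------------------------------------------------------------------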
+# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""SiLU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses 1 * 16 group quantization""" +"""Output uses per-tensor quantization, but should be quantized outside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_silu_backward_legacy_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, + grad_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + s_grad_stride_0, + s_grad_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and silu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # grad ptr + scale_grad_ptr = tl.make_block_ptr( + base=grad_scale_ptr, + shape=(M, SN), + strides=(s_grad_stride_0, s_grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + 
scale_grad = tl.load(scale_grad_ptr) + + grad = grad.to(tl.float32) + scale_grad = scale_grad.to(tl.float32) + + # Dequantize and silu calculation + scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1)) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + grad = grad * scale_grad + + # Actual Calculation of SiLU's backward + sigmoid = 1 / (1.0 + libdevice.exp(-input)) + silu_output = sigmoid + input * sigmoid * (1 - sigmoid) + silu_output = silu_output * grad + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N)) + + # debug + # silu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, silu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_silu_backward_legacy(x, s_x, g, s_g, QB, stochastic=False): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + s_g = s_g.reshape(-1, s_g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_g, dtype=s_g.dtype) + fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_silu_backward_legacy_kernel[grid]( + y, + s_y, + x, + s_x, + g, + s_g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + s_g.stride(0), + s_g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y diff --git a/llava/model/coat/activation/real_quantization/silu_fwd.py b/llava/model/coat/activation/real_quantization/silu_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6dadc42b2e25ad548ec3e2f30cc146b8ce01d0 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/silu_fwd.py @@ -0,0 +1,204 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
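# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of this patch): the forward kernel
# below re-quantizes its SiLU output with the same 1 x QB group scheme its input
# uses (scale = max |value| in the group / fp8_max, then divide and cast). A
# plain-PyTorch sketch of that round trip; the defaults (E4M3, fp8_max = 448,
# the scale_min_thres value) are assumptions for illustration.
import torch

def group_quantize(y, QB, fp8_dtype=torch.float8_e4m3fn, fp8_max=448.0, scale_min_thres=1e-10):
    M, N = y.shape
    groups = y.float().view(M, N // QB, QB)
    scale = (groups.abs().amax(dim=2, keepdim=True) + scale_min_thres) / fp8_max
    q = (groups / scale).view(M, N).to(fp8_dtype)  # one scale per 1 x QB group
    return q, scale.squeeze(-1)

def group_dequantize(q, scale, QB):
    M, N = q.shape
    return (q.float().view(M, N // QB, QB) * scale.view(M, N // QB, 1)).view(M, N)

# Round-trip example: the reconstruction error stays within one FP8 quantization
# step of each group's scale.
y = torch.randn(4, 64, dtype=torch.bfloat16)
q, s = group_quantize(y, QB=16)
y_hat = group_dequantize(q, s, QB=16)
# ---------------------------------------------------------------------------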
+ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +try: + from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block +except: + from common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""SiLU Activation Forward""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_silu_forward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and silu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # Actual Calculation of SiLU + sigmoid = 1 / (1.0 + libdevice.exp(-input)) + silu_output = sigmoid * input + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N)) + + # debug + # silu_output = input + # scale_output = scale_input + + # 
pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, silu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_silu_forward(x, s_x, QB): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(x, dtype=x.dtype) + s_y = torch.empty_like(s_x, dtype=s_x.dtype) + fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_silu_forward_kernel[grid]( + y, + s_y, + x, + s_x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y diff --git a/llava/model/coat/activation/utils.py b/llava/model/coat/activation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b14b79e29cac93f6049c6fa81ff537bd2a5cb8fd --- /dev/null +++ b/llava/model/coat/activation/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from collections import defaultdict + +import torch + + +def quant_get_local_rank() -> int: + return int(os.environ.get("LOCAL_RANK") or 0) + + +record_memory_allocated = defaultdict(list) diff --git a/llava/model/coat/fp8_trainer.py b/llava/model/coat/fp8_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9d68029db9fc7fbbd6af01b34d2cf342b4294f --- /dev/null +++ b/llava/model/coat/fp8_trainer.py @@ -0,0 +1,626 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# In this file I only add the logic in line 411 to 415. 
The rest remains unchanged compared with the original Trainer Class. + +import math +import os +import shutil +import time +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from packaging import version +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler +from transformers import Trainer +from transformers.debug_utils import DebugOption, DebugUnderflowOverflow +from transformers.integrations import hp_params +from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available +from transformers.integrations.tpu import tpu_spmd_dataloader +from transformers.trainer_callback import ExportableState, TrainerState +from transformers.trainer_pt_utils import get_model_param_count +from transformers.trainer_utils import HPSearchBackend, TrainOutput, has_length, speed_metrics +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments +from transformers.utils import ( + is_accelerate_available, + is_apex_available, + is_bitsandbytes_available, + is_datasets_available, + is_galore_torch_available, + is_grokadamw_available, + is_in_notebook, + is_ipex_available, + is_liger_kernel_available, + is_lomo_available, + is_peft_available, + is_safetensors_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_schedulefree_available, + is_torch_compile_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + logging, +) + +if is_apex_available(): + from apex import amp + +if is_datasets_available(): + import datasets + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + from torch_xla import __version__ as XLA_VERSION + + IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) + if IS_XLA_FSDPV2_POST_2_2: + import torch_xla.distributed.spmd as xs + import torch_xla.runtime as xr +else: + IS_XLA_FSDPV2_POST_2_2 = False + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +if is_safetensors_available(): + import safetensors.torch + +if is_peft_available(): + from peft import PeftModel + + +if is_accelerate_available(): + from accelerate import Accelerator + from accelerate import __version__ as accelerate_version + from accelerate import skip_first_batches + from accelerate.utils import ( + DistributedDataParallelKwargs, + DistributedType, + GradientAccumulationPlugin, + load_fsdp_model, + load_fsdp_optimizer, + save_fsdp_model, + save_fsdp_optimizer, + ) + + DATA_SAMPLERS = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + from accelerate.data_loader import SeedableRandomSampler + + DATA_SAMPLERS += [SeedableRandomSampler] + + if is_deepspeed_available(): + from accelerate.utils import DeepSpeedSchedulerWrapper + +if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration + +from .activation.models._fp8manager import FP8Manager + +logger = 
logging.get_logger(__name__) + + +class CoatFP8Trainer(Trainer): + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the intial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + num_train_tokens = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + ) + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
+ num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + if use_accelerator_prepare: + self._fsdp_qlora_plugin_updates() + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. 
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # In this case we are in DDP + LOMO, which should be supported + self.optimizer = self.accelerator.prepare(self.optimizer) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model) + ) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." 
+ ) + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + grad_norm: Optional[float] = None + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + + total_batched_samples = 0 + for epoch in range(epochs_trained, num_train_epochs): + epoch_iterator = train_dataloader + if hasattr(epoch_iterator, "set_epoch"): + epoch_iterator.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + for step, inputs in enumerate(epoch_iterator): + # NOTE: FP8 related + if total_batched_samples % args.gradient_accumulation_steps == 0: + FP8Manager.is_first_microbatch = True + else: + FP8Manager.is_first_microbatch = False + + total_batched_samples += 1 + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
+ ) + else: + self.state.num_input_tokens_seen += ( + torch.sum( + self.accelerator.gather( + torch.tensor( + inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64 + ) + ) + ) + .cpu() + .item() + ) + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + with self.accelerator.accumulate(model): + tr_loss_step = self.training_step(model, inputs) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) + + if ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or + # last step in epoch but step is always smaller than gradient_accumulation_steps + is_last_step_and_steps_less_than_grad_acc + ): + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. 
+ if is_last_step_and_steps_less_than_grad_acc: + self.accelerator.gradient_state._set_sync_gradients(True) + + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED: + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + if self.control.should_epoch_stop or self.control.should_training_stop: + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + if is_torch_xla_available(): + xm.mark_step() + break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sure the model has been saved by process 0. 
+ if is_torch_xla_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + + return TrainOutput(self.state.global_step, train_loss, metrics) diff --git a/llava/model/coat/optimizer/fp8_adamw.py b/llava/model/coat/optimizer/fp8_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..7776ef1539e07773285cf0d220811d6e73070542 --- /dev/null +++ b/llava/model/coat/optimizer/fp8_adamw.py @@ -0,0 +1,515 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
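# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of this patch): per its header
# comment, the only FP8-specific change in CoatFP8Trainer above is the flag set
# at the top of the inner loop, marking the first microbatch of each gradient-
# accumulation window (presumably so FP8Manager-managed quantized weights are
# refreshed once per optimizer step and reused for the remaining microbatches;
# that purpose is inferred here, not stated in the file). Condensed, the logic is:
def is_first_microbatch(total_batched_samples: int, gradient_accumulation_steps: int) -> bool:
    return total_batched_samples % gradient_accumulation_steps == 0

assert [is_first_microbatch(i, 4) for i in range(6)] == [True, False, False, False, True, False]
# ---------------------------------------------------------------------------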
+# +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict, defaultdict +from copy import deepcopy +from itertools import chain +from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Optional, Tuple, Union + +import qoptim_cuda +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer +from typing_extensions import ParamSpec, Self, TypeAlias + +StateDict: TypeAlias = Dict[str, Any] + +convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2} + + +class CoatAdamW(Optimizer): + def __init__( + self, + qargs, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + *, + fused: Optional[bool] = None, + ): + self.qargs = qargs + assert self.qargs.first_order_expansion == self.qargs.second_order_expansion + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + fused=fused, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = torch.tensor(step_val, dtype=torch.float32) + + def _init_group( + self, + group, + params_with_grad, + grads, + amsgrad, + use_expansion, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ): + for p in group["params"]: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # print(f'Param shape: {p.shape}', file=open('debug.txt', 'a')) + # print(f'Param shape: {p.shape}, {p.device}') + + # State initialization + if len(state) == 0: + # This is because kernel launches are costly on CUDA and XLA. 
+ state["step"] = torch.tensor(0.0) + + # Should be torch.float8_e4m3fn + first_order_dtype = convert_str_to_fp8[self.qargs.first_order_bit] + second_order_dtype = convert_str_to_fp8[self.qargs.second_order_bit] + scale_shape = (p.numel() + self.qargs.qgroup_size - 1) // self.qargs.qgroup_size + + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p, dtype=first_order_dtype, memory_format=torch.preserve_format) + state["scale_exp_avg"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype) + if use_expansion: + state["expand_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + state["sqrt_minmax_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p, dtype=second_order_dtype, memory_format=torch.preserve_format) + state["scale_exp_avg_sq"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype) + if use_expansion: + state["expand_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + state["sqrt_minmax_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros(p, memory_format=torch.preserve_format) + + exp_avgs.append(state["exp_avg"]) + scale_exp_avgs.append(state["scale_exp_avg"]) + if use_expansion: + expand_exp_avgs.append(state["expand_exp_avg"]) + sqrt_minmax_exp_avgs.append(state["sqrt_minmax_exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + scale_exp_avg_sqs.append(state["scale_exp_avg_sq"]) + if use_expansion: + expand_exp_avg_sqs.append(state["expand_exp_avg_sq"]) + sqrt_minmax_exp_avg_sqs.append(state["sqrt_minmax_exp_avg_sq"]) + + if group["amsgrad"]: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + @torch._disable_dynamo + def load_state_dict(self, state_dict: StateDict) -> None: + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " "parameter groups") + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups) + ) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return CoatAdamW._process_value_according_to_param_policy(param, value, param_id, param_groups, key) + elif isinstance(value, dict): + return { + k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items() + } + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) + + @staticmethod + def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + key: Hashable = None, + ) -> torch.Tensor: + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. 
+ # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + # UNLESS fused or capturable, see note [special device hosting for step] + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=param.device) + else: + return value + else: + assert value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2, torch.float32] + return value.to(device=param.device) # do not cast optimizer states + # if param.is_floating_point(): + # return value.to(dtype=param.dtype, device=param.device) + # else: + # return value.to(device=param.device) + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + scale_exp_avgs = [] + expand_exp_avgs = [] + sqrt_minmax_exp_avgs = [] + exp_avg_sqs = [] + scale_exp_avg_sqs = [] + expand_exp_avg_sqs = [] + sqrt_minmax_exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group["amsgrad"] + use_expansion = self.qargs.first_order_expansion in ["expansion", "true"] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + amsgrad, + use_expansion, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + Coatadamw( + self.qargs, + params_with_grad, + grads, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + use_expansion=use_expansion, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + qgroup_size=self.qargs.qgroup_size, + expand_min=self.qargs.expand_min, + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def Coatadamw( + qargs, + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + scale_exp_avgs: List[Tensor], + expand_exp_avgs: List[Tensor], + sqrt_minmax_exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + scale_exp_avg_sqs: List[Tensor], + expand_exp_avg_sqs: List[Tensor], + sqrt_minmax_exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + use_expansion: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + qgroup_size: int, + expand_min: int, +): + r"""Functional API that performs AdamW algorithm computation. + + See :class:`~torch.optim.AdamW` for details. 
+ """ + if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + func = _single_tensor_Coatadamw + + func( + qargs, + params, + grads, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + use_expansion=use_expansion, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + qgroup_size=qgroup_size, + expand_min=expand_min, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference + if not torch.jit.is_scripting() and isinstance(x, torch.Tensor): + return x.sqrt() + else: + return sqrt(x) + + +def _single_tensor_Coatadamw( + qargs, + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + scale_exp_avgs: List[Tensor], + expand_exp_avgs: List[Tensor], + sqrt_minmax_exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + scale_exp_avg_sqs: List[Tensor], + expand_exp_avg_sqs: List[Tensor], + sqrt_minmax_exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + use_expansion: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + qgroup_size: int, + expand_min: int, +): + + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] + # First order + exp_avg = exp_avgs[i] + scale_exp_avg = scale_exp_avgs[i] + # Second order + exp_avg_sq = exp_avg_sqs[i] + scale_exp_avg_sq = scale_exp_avg_sqs[i] + step_t = state_steps[i] + + # print(len(exp_avg.unique()), len(exp_avg_sq.unique())) + # print(f"{param.shape}, {grad.shape}, {exp_avg.shape}, {exp_avg_sq.shape}", file=open('debug.txt', 'a')) + + # update step + step_t += 1 + step = int(step_t.item()) + + # Perform Optimizer Step + if use_expansion: + expand_exp_avg = expand_exp_avgs[i] + sqrt_minmax_exp_avg = sqrt_minmax_exp_avgs[i] + expand_exp_avg_sq = expand_exp_avg_sqs[i] + sqrt_minmax_exp_avg_sq = sqrt_minmax_exp_avg_sqs[i] + + qoptim_cuda.fp8_adamw_expand_step( + param, + grad, + exp_avg, + scale_exp_avg, + expand_exp_avg, + sqrt_minmax_exp_avg, + exp_avg_sq, + scale_exp_avg_sq, + expand_exp_avg_sq, + sqrt_minmax_exp_avg_sq, + beta1, + beta2, + lr, + weight_decay, + eps, + step, + qgroup_size, + expand_min, + ) + + else: + qoptim_cuda.fp8_adamw_step( + param, + grad, + exp_avg, + scale_exp_avg, + exp_avg_sq, + scale_exp_avg_sq, + beta1, + beta2, + lr, + weight_decay, + eps, + step, + qgroup_size, + ) diff --git a/llava/model/coat/optimizer/kernels/bindings.cpp b/llava/model/coat/optimizer/kernels/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa3dbf8b145c263000ce3b7151bb73714c55f655 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/bindings.cpp @@ -0,0 +1,10 @@ +#include + +#include "include/fp8_adamw.h" +#include "include/fp8_adamw_expand.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 
+ m.def("fp8_adamw_step", &FP8_AdamW, "Update the quantized optimizer states"); + m.def("fp8_adamw_expand_step", &FP8_AdamW_expand, + "Update the quantized optimizer states, use polynomial expander"); +} diff --git a/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..0d6753b8a2e35f20ec2070c6bce2c9e4acefa0d7 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e40fca6d032a0a3094e50a4fdb04e78998819de2652f0a82e49f49225dd46ba9 +size 235960 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o new file mode 100644 index 0000000000000000000000000000000000000000..be1ed8ce2592d7ce94e249fd52af112a39de1af2 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:daebebfbb9d1a3dbb5e22acb50af317aa74b437f7aa965bd959b19010c5fd144 +size 244976 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..837c2b25b4a1759b529b7e5bf2f74a781a60b73c --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac9c6832be5dccfce3f25ae735c67459afffd12f8f747b5018b3121da3dfd2d7 +size 215440 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o new file mode 100644 index 0000000000000000000000000000000000000000..b6c5c480b055b909abffc646f69cd563c615cc47 Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o differ diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..ec6051ec3846f04b0f2c0f3bbd6abce6a626c3b5 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:799978e9ba5a24e75b58d0e218cf522874bb983d15fee15adb02d9ac0f2d7a10 +size 216240 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o new file mode 100644 index 0000000000000000000000000000000000000000..b29913e078a46679c6aedca63accff5ace9ca5b6 Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o differ diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile 
b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile new file mode 100644 index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile @@ -0,0 +1,6 @@ +all: + nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90 +run: + ./nvcc_qoptim +clean: + rm -f nvcc_qoptim diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu new file mode 100644 index 0000000000000000000000000000000000000000..79f370e76aaf3b99fc53c7fe053731dc9358729f --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu @@ -0,0 +1,478 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace cg = cooperative_groups; +#define WARPSIZE 32 +#define QGROUPSIZE 128 +#define QUANT_MIN_VAL 1e-20 + +template +inline float fp8_dtype_max(const T &variable) { + if (std::is_same::value) { + return 448; + } else if (std::is_same::value) { + return 57344; + } else { + throw "Unsupported data format"; + } +} + +typedef enum { fp8_adamw } myCsrcKernels; + +void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg, + float *fp_exp_avg_sq, float beta1, float beta2, float lr, + float wd, float eps, int step, int qgroup_size, int M, + int N) { + for (int idx = 0; idx < M * N; idx++) { + fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx]; + fp_exp_avg_sq[idx] = + beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + float denom = + (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1; + float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } +} + +template +void printFloatArrayToFile(T *array, int M, int N, + const std::string &outputFileName) { + std::ofstream outputFile(outputFileName); + if (!outputFile.is_open()) { + std::cout << "Failed to open the file." 
<< std::endl; + return; + } + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + int index = i * N + j; + outputFile << std::setw(10) << std::right << std::fixed + << std::setprecision(6) << (float)array[index] << " "; + if (j == N - 1) { + outputFile << "\n"; + } + } + } +} + +template +__global__ void fp8_adamw_csrc( + scalar_t *__restrict__ params, scalar_t *__restrict__ grads, + __nv_fp8_e4m3 *__restrict__ exp_avg, float *__restrict__ scale_exp_avg, + float *__restrict__ expand_exp_avg, float *__restrict__ sqrtminmax_exp_avg, + __nv_fp8_e4m3 *__restrict__ exp_avg_sq, + float *__restrict__ scale_exp_avg_sq, float *__restrict__ expand_exp_avg_sq, + float *__restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, int expand_min, + int total_elements, int total_scale_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_idx = blockIdx.x; + + float float_exp_avg, float_exp_avg_sq; + float correction1, correction2_sqrt; + float denom, update; + + if (idx < total_elements) { + // dequantize the optimizer states + float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx]; + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = sign_exp_avg * + powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) * + sqrtminmax_exp_avg[scale_idx]; + float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx]; + float_exp_avg_sq = + powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) * + sqrtminmax_exp_avg_sq[scale_idx]; + + // calculation of optimizer.step() + float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedFirstMinVal[32]; + __shared__ float sharedSecondMaxVal[32]; + __shared__ float sharedSecondMinVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float firstMinVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + float secondMinVal = fabsf(float_exp_avg_sq); + // Special Handel + if (idx >= total_elements) { + firstMinVal = __int_as_float(0x7f7fffff); + secondMinVal = __int_as_float(0x7f7fffff); + } + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) { + sharedFirstMaxVal[wid] = firstMaxVal; + 
sharedFirstMinVal[wid] = firstMinVal; + sharedSecondMaxVal[wid] = secondMaxVal; + sharedSecondMinVal[wid] = secondMinVal; + } + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmin_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + __shared__ float shared_absmin_exp_avg_sq; + firstMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + firstMinVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMinVal[lane] : 1e9; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + secondMinVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceFirstMinVal = + __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + float reduceSecondMinVal = + __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + } + if (lane == 0) { + shared_absmax_exp_avg = firstMaxVal; + shared_absmin_exp_avg = firstMinVal; + shared_absmax_exp_avg_sq = secondMaxVal; + shared_absmin_exp_avg_sq = secondMinVal; + } + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + // scaling factor before expanding + float fp8MaxVal = 448; + + // dynamic exponent quantization part + firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL; + firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL; + secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL; + + // calculate the ratio and make the scale to center + float firstRatio = firstMaxVal / firstMinVal; + float secondRatio = secondMaxVal / secondMinVal; + float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal); + float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal); + + // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal, + // float_exp_avg); + + // since we use x^k expander, calculate the optimal expanding factor + float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2; + float firstExp = + floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) / + expand_min; // expand_min is set to 8 for example, then the firstExp is + // the multiple of 1/8 + float secondExp = + floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) / + expand_min; + + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = + sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp); + float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp); + + // correspondingly, change the scaling factor + float new_scale_exp_avg = + powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal; + float new_scale_exp_avg_sq = + powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + 
// new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + expand_exp_avg[scale_idx] = firstExp; + expand_exp_avg_sq[scale_idx] = secondExp; + sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax; + sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax; + } +} + +template +void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg, + float *scale_exp_avg, float *expand_exp_avg, + float *sqrtminmax_exp_avg, __nv_fp8_e4m3 *exp_avg_sq, + float *scale_exp_avg_sq, float *expand_exp_avg_sq, + float *sqrtminmax_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, + int expand_min, int M, int N) { + if (algo == fp8_adamw) { + const int block_dim = 128; + int grid_dim = (M * N + qgroup_size - 1) / block_dim; + const dim3 gridDim(grid_dim); + const dim3 blockDim(block_dim); + printf("Yes!\n"); + fp8_adamw_csrc<<>>( + params, grads, exp_avg, scale_exp_avg, expand_exp_avg, + sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq, expand_exp_avg_sq, + sqrtminmax_exp_avg_sq, beta1, beta2, lr, wd, eps, step, qgroup_size, + expand_min, M * N, int(floor(M * N / 128.))); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in kernel launch: " + << cudaGetErrorString(error) << std::endl; + return; + } + printf("Finish!\n"); + } +} + +float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *, + float *, float *, float *, + __nv_fp8_e4m3 *, float *, float *, + float *, // tensor input + float, float, float, float, float, int, + int, int, // float and int input + int, int), // M and N + int M, int N) { + size_t size_param = M * N * sizeof(float); + size_t size_optim = M * N * sizeof(__nv_fp8_e4m3); + size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float); + + // host tensor + float *h_p, *h_g; + __nv_fp8_e4m3 *h_m, *h_v; + float *h_sm, *h_sv; + float *h_fp_m, *h_fp_v; + float *h_cpd_m, *h_cpd_v; + float *h_sqrtmm_m, *h_sqrtmm_v; + + // device tensor + float *d_p, *d_g; + __nv_fp8_e4m3 *d_m, *d_v; + float *d_sm, *d_sv; + float *d_cpd_m, *d_cpd_v; + float *d_sqrtmm_m, *d_sqrtmm_v; + + // device tensor transfer to host + float *hd_p, *hd_g; + __nv_fp8_e4m3 *hd_m, *hd_v; + float *hd_sm, *hd_sv; + float *hd_fp_m, *hd_fp_v; + float *hd_cpd_m, *hd_cpd_v; + float *hd_sqrtmm_m, *hd_sqrtmm_v; + + h_p = (float *)malloc(size_param); + h_g = (float *)malloc(size_param); + h_m = (__nv_fp8_e4m3 *)malloc(size_optim); + h_v = (__nv_fp8_e4m3 *)malloc(size_optim); + h_sm = (float *)malloc(size_scale); + h_sv = (float *)malloc(size_scale); + h_cpd_m = (float *)malloc(size_scale); + h_cpd_v = (float *)malloc(size_scale); + h_sqrtmm_m = (float *)malloc(size_scale); + h_sqrtmm_v = (float *)malloc(size_scale); + h_sv = (float *)malloc(size_scale); + h_fp_m = (float *)malloc(size_param); + h_fp_v = (float *)malloc(size_param); + cudaMalloc(&d_p, size_param); + cudaMalloc(&d_g, size_param); + cudaMalloc(&d_m, size_optim); + cudaMalloc(&d_v, size_optim); + cudaMalloc(&d_sm, size_scale); + cudaMalloc(&d_sv, size_scale); + cudaMalloc(&d_cpd_m, size_scale); + cudaMalloc(&d_cpd_v, size_scale); + cudaMalloc(&d_sqrtmm_m, size_scale); + cudaMalloc(&d_sqrtmm_v, 
size_scale); + hd_p = (float *)malloc(size_param); + hd_g = (float *)malloc(size_param); + hd_m = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_v = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_sm = (float *)malloc(size_scale); + hd_sv = (float *)malloc(size_scale); + hd_fp_m = (float *)malloc(size_param); + hd_fp_v = (float *)malloc(size_param); + hd_cpd_m = (float *)malloc(size_scale); + hd_cpd_v = (float *)malloc(size_scale); + hd_sqrtmm_m = (float *)malloc(size_scale); + hd_sqrtmm_v = (float *)malloc(size_scale); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data copy: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + srand(0); + // random initialization for CPU tensor + for (int i = 0; i < M * N; i++) { + h_p[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_g[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + } + for (int i = 0; i < int(ceilf(M * N / 128.)); i++) { + h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_cpd_m[i] = 2; + h_cpd_v[i] = 3. / 8.; + h_sqrtmm_m[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_sqrtmm_v[i] = (float)(rand() / (float(RAND_MAX) / 10)); + + printf("scale is %f\n", h_sm[i]); + } + for (int i = 0; i < M * N; i++) { + h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))]; + h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))]; + } + float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8; + int step = 100, qgroup_size = 128, expand_min = 16; + + printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt"); + printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt"); + printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt"); + printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt"); + printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt"); + printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt"); + printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt"); + + cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_cpd_m, h_cpd_m, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_cpd_v, h_cpd_v, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sqrtmm_m, h_sqrtmm_m, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sqrtmm_v, h_sqrtmm_v, size_scale, cudaMemcpyHostToDevice); + + fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step, + qgroup_size, M, N); + + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data initialization: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + myGPUKernel(d_p, d_g, d_m, d_sm, d_cpd_m, d_sqrtmm_m, d_v, d_sv, d_cpd_v, + d_sqrtmm_v, beta1, beta2, lr, wd, eps, step, qgroup_size, + expand_min, M, N); + + cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost); + 
cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_cpd_m, d_cpd_m, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_cpd_v, d_cpd_v, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sqrtmm_m, d_sqrtmm_m, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sqrtmm_v, d_sqrtmm_v, size_scale, cudaMemcpyDeviceToHost); + + for (int i = 0; i < M * N; i++) { + hd_fp_m[i] = pow((float)hd_m[i] * hd_sm[int(floor(i / 128.))], + 1 / hd_cpd_m[int(floor(i / 128.))]) * + hd_sqrtmm_m[int(floor(i / 128.))]; + hd_fp_v[i] = pow((float)hd_v[i] * hd_sv[int(floor(i / 128.))], + 1 / hd_cpd_v[int(floor(i / 128.))]) * + hd_sqrtmm_v[int(floor(i / 128.))]; + } + printFloatArrayToFile(h_p, M, N, "CPU_param.txt"); + printFloatArrayToFile(hd_p, M, N, "GPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "CPU_grad.txt"); + printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt"); + printFloatArrayToFile(h_m, M, N, "CPU_m1.txt"); + printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt"); + printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt"); + printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt"); + printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt"); + printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt"); + printFloatArrayToFile(h_v, M, N, "CPU_v2.txt"); + printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt"); + printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt"); + printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt"); + printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt"); + printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt"); + + printFloatArrayToFile(hd_cpd_m, 1, int(ceilf(M * N / 128.)), "GPU_cpd_m.txt"); + printFloatArrayToFile(hd_cpd_v, 1, int(ceilf(M * N / 128.)), "GPU_cpd_v.txt"); + printFloatArrayToFile(hd_sqrtmm_m, 1, int(ceilf(M * N / 128.)), + "GPU_sqrtmm_m.txt"); + printFloatArrayToFile(hd_sqrtmm_v, 1, int(ceilf(M * N / 128.)), + "GPU_sqrtmm_v.txt"); + + return 0.; +} + +int main() { + const int M = 1, N = 7; + float max_error = testMaxError(myKernelLauncher, M, N); +} diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile new file mode 100644 index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile @@ -0,0 +1,6 @@ +all: + nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90 +run: + ./nvcc_qoptim +clean: + rm -f nvcc_qoptim diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu new file mode 100644 index 0000000000000000000000000000000000000000..dde952654b71f11dc6de051a7089df38e14c6b57 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu @@ -0,0 +1,354 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cg = cooperative_groups; +#define WARPSIZE 32 +#define QGROUPSIZE 128 +#define QUANT_MIN_VAL 1e-20 + +template +inline float fp8_dtype_max(const T &variable) { + if (std::is_same::value) { + return 448; + } else if (std::is_same::value) { + return 57344; + } else { + throw "Unsupported data format"; + } +} + +typedef enum { fp8_adamw } myCsrcKernels; + +void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg, + float *fp_exp_avg_sq, float beta1, float beta2, float lr, + 
float wd, float eps, int step, int qgroup_size, int M, + int N) { + for (int idx = 0; idx < M * N; idx++) { + fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx]; + fp_exp_avg_sq[idx] = + beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + float denom = + (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1; + float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } +} + +template +void printFloatArrayToFile(T *array, int M, int N, + const std::string &outputFileName) { + std::ofstream outputFile(outputFileName); + if (!outputFile.is_open()) { + std::cout << "Failed to open the file." << std::endl; + return; + } + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + int index = i * N + j; + outputFile << std::setw(10) << std::right << std::fixed + << std::setprecision(6) << (float)array[index] << " "; + if (j == N - 1) { + outputFile << "\n"; + } + } + } +} + +template +__global__ void fp8_adamw_csrc(scalar_t *__restrict__ params, + scalar_t *__restrict__ grads, + __nv_fp8_e4m3 *__restrict__ exp_avg, + float *__restrict__ scale_exp_avg, + __nv_fp8_e4m3 *__restrict__ exp_avg_sq, + float *__restrict__ scale_exp_avg_sq, + float beta1, float beta2, float lr, float wd, + float eps, int step, int qgroup_size, + int total_elements, int total_scale_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_idx = blockIdx.x; + + float float_exp_avg, float_exp_avg_sq; + float correction1, correction2_sqrt; + float denom, update; + + if (idx < total_elements) { + // dequantize the optimizer states + float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx]; + float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx]; + + // calculation of optimizer.step() + float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedSecondMaxVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal; + if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal; + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + firstMaxVal = + (threadIdx.x < 
blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + } + if (lane == 0) shared_absmax_exp_avg = firstMaxVal; + if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal; + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + float fp8MaxVal = 448; + + shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL; + shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + + float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal; + float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + // new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + } +} + +template +void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg, + float *scale_exp_avg, __nv_fp8_e4m3 *exp_avg_sq, + float *scale_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, + int M, int N) { + if (algo == fp8_adamw) { + const int block_dim = 128; + int grid_dim = (M * N + qgroup_size - 1) / block_dim; + const dim3 gridDim(grid_dim); + const dim3 blockDim(block_dim); + printf("Yes!\n"); + fp8_adamw_csrc<<>>( + params, grads, exp_avg, scale_exp_avg, exp_avg_sq, scale_exp_avg_sq, + beta1, beta2, lr, wd, eps, step, qgroup_size, M * N, + int(floor(M * N / 128.))); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in kernel launch: " + << cudaGetErrorString(error) << std::endl; + return; + } + printf("Finish!\n"); + } +} + +float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *, + float *, __nv_fp8_e4m3 *, float *, float, + float, float, float, float, int, int, + int, int), + int M, int N) { + size_t size_param = M * N * sizeof(float); + size_t size_optim = M * N * sizeof(__nv_fp8_e4m3); + size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float); + + // host tensor + float *h_p, *h_g; + __nv_fp8_e4m3 *h_m, *h_v; + float *h_sm, *h_sv; + float *h_fp_m, *h_fp_v; + + // device tensor + float *d_p, *d_g; + __nv_fp8_e4m3 *d_m, *d_v; + float *d_sm, *d_sv; + + // device tensor transfer to host + float *hd_p, *hd_g; + __nv_fp8_e4m3 *hd_m, *hd_v; + float *hd_sm, *hd_sv; + float *hd_fp_m, *hd_fp_v; + + h_p = (float *)malloc(size_param); + h_g = (float *)malloc(size_param); + h_m = (__nv_fp8_e4m3 *)malloc(size_optim); + h_v = (__nv_fp8_e4m3 *)malloc(size_optim); + h_sm = (float *)malloc(size_scale); + 
h_sv = (float *)malloc(size_scale); + h_fp_m = (float *)malloc(size_param); + h_fp_v = (float *)malloc(size_param); + cudaMalloc(&d_p, size_param); + cudaMalloc(&d_g, size_param); + cudaMalloc(&d_m, size_optim); + cudaMalloc(&d_v, size_optim); + cudaMalloc(&d_sm, size_scale); + cudaMalloc(&d_sv, size_scale); + hd_p = (float *)malloc(size_param); + hd_g = (float *)malloc(size_param); + hd_m = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_v = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_sm = (float *)malloc(size_scale); + hd_sv = (float *)malloc(size_scale); + hd_fp_m = (float *)malloc(size_param); + hd_fp_v = (float *)malloc(size_param); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data copy: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + srand(0); + // random initialization for CPU tensor + for (int i = 0; i < M * N; i++) { + h_p[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_g[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + } + for (int i = 0; i < int(ceilf(M * N / 128.)); i++) { + h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10)); + printf("scale is %f\n", h_sm[i]); + } + for (int i = 0; i < M * N; i++) { + h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))]; + h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))]; + } + float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8; + int step = 100, qgroup_size = 128; + + printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt"); + printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt"); + printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt"); + printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt"); + printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt"); + printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt"); + printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt"); + + cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice); + + fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step, + qgroup_size, M, N); + + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data initialization: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + myGPUKernel(d_p, d_g, d_m, d_sm, d_v, d_sv, beta1, beta2, lr, wd, eps, step, + qgroup_size, M, N); + + cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost); + + for (int i = 0; i < M * N; i++) { + hd_fp_m[i] = (float)hd_m[i] * hd_sm[int(floor(i / 128.))]; + hd_fp_v[i] = (float)hd_v[i] * hd_sv[int(floor(i / 128.))]; + } + printFloatArrayToFile(h_p, M, N, "CPU_param.txt"); + printFloatArrayToFile(hd_p, M, N, "GPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "CPU_grad.txt"); + 
printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt");
+  printFloatArrayToFile(h_m, M, N, "CPU_m1.txt");
+  printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt");
+  printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt");
+  printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt");
+  printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt");
+  printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt");
+  printFloatArrayToFile(h_v, M, N, "CPU_v2.txt");
+  printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt");
+  printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt");
+  printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt");
+  printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt");
+  printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt");
+
+  return 0.;
+}
+
+int main() {
+  const int M = 1, N = 7;
+  // pick the fp8_adamw variant of the templated launcher
+  float max_error = testMaxError(myKernelLauncher<fp8_adamw>, M, N);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..57190c713309a19eafe35b4651e55d7c5bd576f7
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp
@@ -0,0 +1,26 @@
+#include <torch/extension.h>
+#include <torch/torch.h>
+
+void FP8_AdamW_cuda(torch::Tensor params,  // parameter
+                    torch::Tensor grads,   // gradient
+                    torch::Tensor exp_avg, // first order momentum
+                    torch::Tensor scale_exp_avg,
+                    torch::Tensor exp_avg_sq, // second order momentum
+                    torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+                    float lr, float wd, float eps, int step,
+                    int qgroup_size // other parameters
+);
+
+void FP8_AdamW(torch::Tensor params,  // parameter
+               torch::Tensor grads,   // gradient
+               torch::Tensor exp_avg, // first order momentum
+               torch::Tensor scale_exp_avg,
+               torch::Tensor exp_avg_sq, // second order momentum
+               torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+               float lr, float wd, float eps, int step,
+               int qgroup_size) { // other parameters
+
+  FP8_AdamW_cuda(params, grads, exp_avg, scale_exp_avg, exp_avg_sq,
+                 scale_exp_avg_sq, beta1, beta2, lr, wd, eps, step,
+                 qgroup_size);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e652845af5d228237d3c8df758df1a4fe94bcd4
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu
@@ -0,0 +1,163 @@
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <torch/extension.h>
+
+#define QUANT_MIN_VAL 1e-20
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+
+template <typename scalar_t>
+__global__ void fp8_adamw_cuda_kernel(
+    scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
+    __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
+    __nv_fp8_e4m3* __restrict__ exp_avg_sq,
+    float* __restrict__ scale_exp_avg_sq, float beta1, float beta2, float lr,
+    float wd, float eps, int step, int qgroup_size, int total_elements,
+    int total_scale_elements) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int scale_idx = blockIdx.x;
+
+  float float_exp_avg, float_exp_avg_sq;
+  float correction1, correction2_sqrt;
+  float denom, update;
+
+  if (idx < total_elements) {
+    // dequantize the optimizer states
+    float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+    float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+
+    // calculation of optimizer.step()
+    float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * 
grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedSecondMaxVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal; + if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal; + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + firstMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + } + if (lane == 0) shared_absmax_exp_avg = firstMaxVal; + if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal; + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + float fp8MaxVal = 448; + + shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL; + shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + + float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal; + float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + // new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + } +} + +void FP8_AdamW_cuda(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + 
torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+                    float lr, float wd, float eps, int step,
+                    int qgroup_size) { // other parameters
+
+  // CUDA Blocks
+  int total_elements = params.numel();
+  int total_scale_elements = scale_exp_avg.numel();
+  AT_ASSERTM(qgroup_size == 128,
+             "Only Support 128 per-group quantization currently");
+  const int block_dim = 128; // This should be equal to the qgroup_size
+  int grid_dim = (total_elements + qgroup_size - 1) / block_dim;
+  AT_ASSERTM(grid_dim == scale_exp_avg.numel());
+  AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel());
+  const dim3 blocks(grid_dim);
+
+  // Execution
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
+        fp8_adamw_cuda_kernel<scalar_t><<<blocks, block_dim>>>(
+            params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
+            (__nv_fp8_e4m3*)exp_avg.data_ptr(),
+            scale_exp_avg.data_ptr<float>(),
+            (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(),
+            scale_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps, step,
+            qgroup_size, total_elements, total_scale_elements);
+      }));
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d50f96dc9d4dc186b660d6b6ff56ff9df1c2df89
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp
@@ -0,0 +1,34 @@
+#include <torch/extension.h>
+#include <torch/torch.h>
+
+void FP8_AdamW_expand_cuda(torch::Tensor params,  // parameter
+                           torch::Tensor grads,   // gradient
+                           torch::Tensor exp_avg, // first order momentum
+                           torch::Tensor scale_exp_avg,
+                           torch::Tensor expand_exp_avg,
+                           torch::Tensor sqrtminmax_exp_avg,
+                           torch::Tensor exp_avg_sq, // second order momentum
+                           torch::Tensor scale_exp_avg_sq,
+                           torch::Tensor expand_exp_avg_sq,
+                           torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+                           float beta2, float lr, float wd, float eps, int step,
+                           int qgroup_size, int expand_min // other parameters
+);
+
+void FP8_AdamW_expand(torch::Tensor params,  // parameter
+                      torch::Tensor grads,   // gradient
+                      torch::Tensor exp_avg, // first order momentum
+                      torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg,
+                      torch::Tensor sqrtminmax_exp_avg,
+                      torch::Tensor exp_avg_sq, // second order momentum
+                      torch::Tensor scale_exp_avg_sq,
+                      torch::Tensor expand_exp_avg_sq,
+                      torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+                      float beta2, float lr, float wd, float eps, int step,
+                      int qgroup_size, int expand_min) { // other parameters
+
+  FP8_AdamW_expand_cuda(params, grads, exp_avg, scale_exp_avg, expand_exp_avg,
+                        sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq,
+                        expand_exp_avg_sq, sqrtminmax_exp_avg_sq, beta1, beta2,
+                        lr, wd, eps, step, qgroup_size, expand_min);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6130e3716fe5d73bf66cc6f915c3bf99a41f785b
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu
@@ -0,0 +1,247 @@
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <torch/extension.h>
+
+#define QUANT_MIN_VAL 1e-20
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+
+template <typename scalar_t>
+__global__ void fp8_adamw_cuda_expand_kernel(
+    scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
+    __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
+    float* __restrict__ expand_exp_avg, float* __restrict__ sqrtminmax_exp_avg,
+    __nv_fp8_e4m3* 
__restrict__ exp_avg_sq, + float* __restrict__ scale_exp_avg_sq, float* __restrict__ expand_exp_avg_sq, + float* __restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, int expand_min, + int total_elements, int total_scale_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_idx = blockIdx.x; + + float float_exp_avg, float_exp_avg_sq; + float correction1, correction2_sqrt; + float denom, update; + + if (idx < total_elements) { + // dequantize the optimizer states + float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx]; + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = sign_exp_avg * + powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) * + sqrtminmax_exp_avg[scale_idx]; + float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx]; + float_exp_avg_sq = + powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) * + sqrtminmax_exp_avg_sq[scale_idx]; + + // calculation of optimizer.step() + float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedFirstMinVal[32]; + __shared__ float sharedSecondMaxVal[32]; + __shared__ float sharedSecondMinVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float firstMinVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + float secondMinVal = fabsf(float_exp_avg_sq); + // Special Handel + if (idx >= total_elements) { + firstMinVal = __int_as_float(0x7f7fffff); + secondMinVal = __int_as_float(0x7f7fffff); + } + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) { + sharedFirstMaxVal[wid] = firstMaxVal; + sharedFirstMinVal[wid] = firstMinVal; + sharedSecondMaxVal[wid] = secondMaxVal; + sharedSecondMinVal[wid] = secondMinVal; + } + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmin_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + __shared__ float shared_absmin_exp_avg_sq; + firstMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + firstMinVal = + (threadIdx.x < blockDim.x / warpSize) ? 
sharedFirstMinVal[lane] : 1e9; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + secondMinVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceFirstMinVal = + __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + float reduceSecondMinVal = + __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + } + if (lane == 0) { + shared_absmax_exp_avg = firstMaxVal; + shared_absmin_exp_avg = firstMinVal; + shared_absmax_exp_avg_sq = secondMaxVal; + shared_absmin_exp_avg_sq = secondMinVal; + } + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + // scaling factor before expanding + float fp8MaxVal = 448; + + // dynamic exponent quantization part + firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL; + firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL; + secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL; + + // calculate the ratio and make the scale to center + float firstRatio = firstMaxVal / firstMinVal; + float secondRatio = secondMaxVal / secondMinVal; + float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal); + float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal); + + // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal, + // float_exp_avg); + + // since we use x^k expander, calculate the optimal expanding factor + float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2; + float firstExp = + floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) / + expand_min; // expand_min is set to 8 for example, then the firstExp is + // the multiple of 1/8 + float secondExp = + floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) / + expand_min; + + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = + sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp); + float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp); + + // correspondingly, change the scaling factor + float new_scale_exp_avg = + powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal; + float new_scale_exp_avg_sq = + powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + // new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + expand_exp_avg[scale_idx] = firstExp; + 
expand_exp_avg_sq[scale_idx] = secondExp;
+    sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax;
+    sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax;
+  }
+}
+
+void FP8_AdamW_expand_cuda(torch::Tensor params,  // parameter
+                           torch::Tensor grads,   // gradient
+                           torch::Tensor exp_avg, // first order momentum
+                           torch::Tensor scale_exp_avg,
+                           torch::Tensor expand_exp_avg,
+                           torch::Tensor sqrtminmax_exp_avg,
+                           torch::Tensor exp_avg_sq, // second order momentum
+                           torch::Tensor scale_exp_avg_sq,
+                           torch::Tensor expand_exp_avg_sq,
+                           torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+                           float beta2, float lr, float wd, float eps, int step,
+                           int qgroup_size,
+                           int expand_min) { // other parameters
+
+  // CUDA Blocks
+  int total_elements = params.numel();
+  int total_scale_elements = scale_exp_avg.numel();
+  AT_ASSERTM(qgroup_size == 128,
+             "Only Support 128 per-group quantization currently");
+  const int block_dim = 128; // This should be equal to the qgroup_size
+  int grid_dim = (total_elements + qgroup_size - 1) / block_dim;
+  AT_ASSERTM(grid_dim == scale_exp_avg.numel());
+  AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel());
+  const dim3 blocks(grid_dim);
+
+  // Execution
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
+        fp8_adamw_cuda_expand_kernel<scalar_t><<<blocks, block_dim>>>(
+            params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
+            (__nv_fp8_e4m3*)exp_avg.data_ptr(),
+            scale_exp_avg.data_ptr<float>(), expand_exp_avg.data_ptr<float>(),
+            sqrtminmax_exp_avg.data_ptr<float>(),
+            (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(),
+            scale_exp_avg_sq.data_ptr<float>(),
+            expand_exp_avg_sq.data_ptr<float>(),
+            sqrtminmax_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps,
+            step, qgroup_size, expand_min, total_elements,
+            total_scale_elements);
+      }));
+}
diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h
new file mode 100644
index 0000000000000000000000000000000000000000..c04970da0dd564ea27e1e6203a38206d9506cac8
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h
@@ -0,0 +1,14 @@
+#ifndef FP8_ADAMW
+#define FP8_ADAMW
+
+#include <torch/extension.h>
+
+void FP8_AdamW(torch::Tensor params,  // parameter
+               torch::Tensor grads,   // gradient
+               torch::Tensor exp_avg, // first order momentum
+               torch::Tensor scale_exp_avg,
+               torch::Tensor exp_avg_sq, // second order momentum
+               torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+               float lr, float wd, float eps, int step, int qgroup_size);
+
+#endif // FP8_ADAMW
diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1fb69825933154b1a867efaecd54cbdf100201b
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h
@@ -0,0 +1,18 @@
+#ifndef FP8_ADAMW_EXPAND
+#define FP8_ADAMW_EXPAND
+
+#include <torch/extension.h>
+
+void FP8_AdamW_expand(torch::Tensor params,  // parameter
+                      torch::Tensor grads,   // gradient
+                      torch::Tensor exp_avg, // first order momentum
+                      torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg,
+                      torch::Tensor sqrtminmax_exp_avg,
+                      torch::Tensor exp_avg_sq, // second order momentum
+                      torch::Tensor scale_exp_avg_sq,
+                      torch::Tensor expand_exp_avg_sq,
+                      torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+                      float beta2, float lr, float wd, float eps, int step,
+                      int qgroup_size, int expand_min);
+
+#endif // FP8_ADAMW_EXPAND
diff --git a/llava/model/coat/optimizer/kernels/setup.py b/llava/model/coat/optimizer/kernels/setup.py
new file mode 100644
index 
0000000000000000000000000000000000000000..8453efeaa472f3b094a2c157acbebf4c571eb09d --- /dev/null +++ b/llava/model/coat/optimizer/kernels/setup.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="qoptim_cuda", + ext_modules=[ + CUDAExtension( + name="qoptim_cuda", + sources=[ + "fp8_adamw_cuda.cpp", + "fp8_adamw_cuda_kernel.cu", + "fp8_adamw_expand_cuda.cpp", + "fp8_adamw_expand_cuda_kernel.cu", + "bindings.cpp", + ], + # include_dirs=[ + # 'include' + # ], + extra_compile_args={ + "nvcc": [ + "-O3", + "-std=c++17", + "-gencode=arch=compute_90,code=compute_90", + "-DTORCH_USE_CUDA_DSA", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + ] + }, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/llava/model/configuration_llava.py b/llava/model/configuration_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..b413354ab70581faf619a632e1934a1a956250d0 --- /dev/null +++ b/llava/model/configuration_llava.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal, Optional + +from pydantic import BaseModel, Field +from transformers import PretrainedConfig + + +class LlavaConfig(PretrainedConfig): + model_type = "llava" + + def __init__( + self, + llm_cfg=None, + vision_tower_cfg=None, + speech_tower_cfg=None, + sound_tower_cfg=None, + mm_projector_cfg=None, + speech_mm_projector_cfg=None, + sound_mm_projector_cfg=None, + architectures=None, + resume_path=None, + hidden_size=None, + mm_hidden_size=None, + speech_hidden_size=None, + sound_hidden_size=None, + image_aspect_ratio=None, + num_video_frames=None, + fps=None, + mm_vision_select_layer=None, + mm_vision_select_feature=None, + mm_use_im_start_end=False, + mm_use_im_patch_token=False, + mm_projector_lr=None, + speech_mm_projector_lr=None, + sound_mm_projector_lr=None, + vision_tower_lr=None, + speech_tower_lr=None, + sound_tower_lr=None, + vision_resolution=None, + interpolate_mode=None, + s2=None, + dynamic_s2=None, + s2_scales=None, + s2_max_split_size=None, + s2_resize_output_to_scale_idx=0, + min_tiles: Optional[int] = 1, + max_tiles: Optional[int] = 12, + video_max_tiles: Optional[int] = 1, + num_time_tokens=None, + time_token_format=None, + image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}', + video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}', + speech_encoder: str = '{"_target_": "llava.model.encoders.BasicSpeechEncoder"}', + sound_encoder: str = '{"_target_": "llava.model.encoders.BasicSoundEncoder"}', + **kwargs, + ): + super().__init__() + self.architectures = architectures + self.llm_cfg = llm_cfg + self.vision_tower_cfg = vision_tower_cfg + self.speech_tower_cfg = speech_tower_cfg + self.sound_tower_cfg = sound_tower_cfg + self.mm_projector_cfg = mm_projector_cfg + self.speech_mm_projector_cfg = speech_mm_projector_cfg + self.sound_mm_projector_cfg = sound_mm_projector_cfg + self.resume_path = resume_path + + self.hidden_size = hidden_size + self.mm_hidden_size = mm_hidden_size + self.speech_hidden_size = speech_hidden_size + self.sound_hidden_size = sound_hidden_size + self.image_aspect_ratio = image_aspect_ratio + self.num_video_frames = num_video_frames + self.fps = fps + self.mm_vision_select_layer = mm_vision_select_layer + self.mm_vision_select_feature = mm_vision_select_feature + self.mm_use_im_start_end = mm_use_im_start_end + self.mm_use_im_patch_token = mm_use_im_patch_token + self.mm_projector_lr = mm_projector_lr + self.speech_mm_projector_lr = speech_mm_projector_lr + self.sound_mm_projector_lr = sound_mm_projector_lr + self.vision_tower_lr = vision_tower_lr + self.speech_tower_lr = speech_tower_lr + self.sound_tower_lr = sound_tower_lr + self.vision_resolution = vision_resolution + self.interpolate_mode = interpolate_mode + self.s2 = s2 + self.dynamic_s2 = dynamic_s2 + self.s2_scales = s2_scales + self.s2_max_split_size = s2_max_split_size + self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx + self.min_tiles = min_tiles + self.max_tiles = max_tiles + self.video_max_tiles = video_max_tiles + self.num_time_tokens = num_time_tokens + self.time_token_format = time_token_format + + self.image_encoder = image_encoder + self.video_encoder = video_encoder + self.speech_encoder = speech_encoder + self.sound_encoder = sound_encoder + + +class JsonSchemaResponseFormat(BaseModel): + schema_: str = Field(alias="schema") + + +class ResponseFormat(BaseModel): + type: Literal["text", "json_object", "json_schema"] + json_schema: 
Optional[JsonSchemaResponseFormat] = None diff --git a/llava/model/deprecate_consolidate.py b/llava/model/deprecate_consolidate.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb9619df19e4e8791920b4c22d38763538454b1 --- /dev/null +++ b/llava/model/deprecate_consolidate.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Usage: +python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate +""" +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llava.model import * +from llava.model.utils import auto_upgrade + + +def consolidate_ckpt(src_path, dst_path): + print("Loading model") + auto_upgrade(src_path) + src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) + src_model.save_pretrained(dst_path) + src_tokenizer.save_pretrained(dst_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, required=True) + parser.add_argument("--dst", type=str, required=True) + + args = parser.parse_args() + + consolidate_ckpt(args.src, args.dst) diff --git a/llava/model/encoders/__init__.py b/llava/model/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e80d23d6693e11f0779e342c9b6384d78b4c79 --- /dev/null +++ b/llava/model/encoders/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from .image import * +from .video import * +from .speech import * +from .sound import * + diff --git a/llava/model/encoders/__pycache__/__init__.cpython-310.pyc b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51821273a8633724bacf886f82c5d3f6dccef0f2 Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/llava/model/encoders/__pycache__/__init__.cpython-311.pyc b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91555cc8635606593bafd5cbfc61fff128718d6e Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc differ diff --git a/llava/model/encoders/__pycache__/base.cpython-310.pyc b/llava/model/encoders/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1a5e6a8761723bbebdd2cec99f1d26496e05ce0 Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-310.pyc differ diff --git a/llava/model/encoders/__pycache__/base.cpython-311.pyc b/llava/model/encoders/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e0467e667be5f975f22ccd0d080fe75c8ea560c Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-311.pyc differ diff --git a/llava/model/encoders/base.py b/llava/model/encoders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9230c954a0f711b063e9b9929cbe0e3d40b55a --- /dev/null +++ b/llava/model/encoders/base.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from torch import nn + +__all__ = ["BaseEncoder"] + + +class BaseEncoder(nn.Module): + def __init__(self, parent: nn.Module) -> None: + super().__init__() + self._parent = [parent] + + @property + def parent(self) -> nn.Module: + return self._parent[0] diff --git a/llava/model/encoders/image/__init__.py b/llava/model/encoders/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb --- /dev/null +++ b/llava/model/encoders/image/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * diff --git a/llava/model/encoders/image/basic.py b/llava/model/encoders/image/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..6428e38f56ffe421edd25bf45f378b25849874f8 --- /dev/null +++ b/llava/model/encoders/image/basic.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
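A note on `BaseEncoder` in llava/model/encoders/base.py above: the parent model is stashed in a one-element Python list rather than assigned directly, presumably so that PyTorch does not register the parent as a submodule of the encoder (which would pull all of the parent's parameters into the encoder's `state_dict()` and create a cycle in the module tree). A minimal standalone sketch of the difference, using a hypothetical `ToyParent` module:

from torch import nn

class ToyParent(nn.Module):
    # Hypothetical parent with one linear layer, for illustration only.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

class DirectChild(nn.Module):
    def __init__(self, parent: nn.Module):
        super().__init__()
        self.parent = parent  # registered as a submodule of the child

class HiddenChild(nn.Module):
    def __init__(self, parent: nn.Module):
        super().__init__()
        self._parent = [parent]  # hidden inside a list, as BaseEncoder does

    @property
    def parent(self) -> nn.Module:
        return self._parent[0]

toy = ToyParent()
print(len(list(DirectChild(toy).parameters())))  # 2: the parent's weight and bias leak in
print(len(list(HiddenChild(toy).parameters())))  # 0: the encoder itself stays parameter-free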
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicImageEncoder"] + + +class BasicImageEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + if start_token_embeds is not None: + features = torch.cat([start_token_embeds, features], dim=0) + if end_token_embeds is not None: + features = torch.cat([features, end_token_embeds], dim=0) + return features + + def forward(self, images: List[torch.Tensor], config: Dict[str, Any], frame_times=None) -> List[torch.Tensor]: + images = torch.stack(images, dim=0) + features = self.parent.encode_images(images, block_sizes=config.get("block_sizes")) + + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/sound/__init__.py b/llava/model/encoders/sound/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb --- /dev/null +++ b/llava/model/encoders/sound/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * diff --git a/llava/model/encoders/sound/basic.py b/llava/model/encoders/sound/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..beb5ba5af1ba40fc9a9cba5b41e37673a6cdfeb2 --- /dev/null +++ b/llava/model/encoders/sound/basic.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
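`BasicImageEncoder` above embeds the optional `start_tokens`/`end_tokens` strings with the parent LLM's embedding table and concatenates them around each image's feature sequence. A tiny self-contained sketch of that `_process_features` concatenation with dummy tensors (sizes are illustrative, not the model's real dimensions):

import torch

hidden = 8                             # illustrative hidden size
features = torch.randn(16, hidden)     # 16 visual tokens for one image
end_embeds = torch.randn(1, hidden)    # embedding of the default end_tokens "\n"

# Equivalent to _process_features(features, None, end_embeds): no start tokens,
# one end-of-image token appended along the sequence dimension.
wrapped = torch.cat([features, end_embeds], dim=0)
print(wrapped.shape)                   # torch.Size([17, 8])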
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicSoundEncoder"] + + +class BasicSoundEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + features = features.to(self.parent.device) + if start_token_embeds is not None: + features = torch.cat([start_token_embeds, features], dim=0) + if end_token_embeds is not None: + features = torch.cat([features, end_token_embeds], dim=0) + return features + + def forward(self, sounds: List[torch.Tensor], config: Dict[str, Any], masks: Dict[str, Any]) -> List[torch.Tensor]: + sounds = torch.stack(sounds, dim=0) + masks = torch.stack(masks, dim=0) + features = self.parent.encode_sound(sounds, masks) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/speech/__init__.py b/llava/model/encoders/speech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb --- /dev/null +++ b/llava/model/encoders/speech/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * diff --git a/llava/model/encoders/speech/basic.py b/llava/model/encoders/speech/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..68226501c5fcfc8d0cc6030ca71b324a7cb7f360 --- /dev/null +++ b/llava/model/encoders/speech/basic.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicSpeechEncoder"] + + +class BasicSpeechEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + if start_token_embeds is not None: + features = torch.cat([start_token_embeds, features], dim=0) + if end_token_embeds is not None: + features = torch.cat([features, end_token_embeds], dim=0) + return features + + def forward(self, speeches: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: + speeches = torch.stack(speeches, dim=0) + features = self.parent.encode_speech(speeches) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/video/__init__.py b/llava/model/encoders/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b00796ffe81fd8d1cb50cd604a353e8c4e8199ad --- /dev/null +++ b/llava/model/encoders/video/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * +from .tsp import * diff --git a/llava/model/encoders/video/basic.py b/llava/model/encoders/video/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..8b84c6917874686d071ef60a5bd6f2cab0a7f9e4 --- /dev/null +++ b/llava/model/encoders/video/basic.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicVideoEncoder"] + + +class BasicVideoEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + if start_token_embeds is not None: + start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0) + features = torch.cat([start_embeds, features], dim=1) + if end_token_embeds is not None: + end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0) + features = torch.cat([features, end_embeds], dim=1) + return features.flatten(0, 1) + + def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: + num_frames = [video.shape[0] for video in videos] + images = torch.cat(videos, dim=0) + features = self.parent.encode_images(images) + features = torch.split(features, num_frames) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/video/tsp.py b/llava/model/encoders/video/tsp.py new file mode 100644 index 0000000000000000000000000000000000000000..efa06530b3aea80a994efb5bd3d5db70f46bd577 --- /dev/null +++ b/llava/model/encoders/video/tsp.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
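`BasicVideoEncoder` above handles one extra dimension compared with the image encoder: start/end embeddings are replicated once per frame before concatenation, and the frames are then flattened into a single token sequence. A short shape walk-through with toy sizes assumed purely for illustration:

import torch

frames, tokens_per_frame, hidden = 4, 9, 8           # toy sizes
features = torch.randn(frames, tokens_per_frame, hidden)
end_embeds = torch.randn(1, hidden)                   # embedded "\n"

# Mirror _process_features with start_tokens=None: one end token per frame,
# then frames and tokens are merged into one sequence for the LLM.
end = torch.stack([end_embeds] * frames, dim=0)       # (4, 1, 8)
wrapped = torch.cat([features, end], dim=1)           # (4, 10, 8)
print(wrapped.flatten(0, 1).shape)                    # torch.Size([40, 8]) = 4 * (9 + 1) tokens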
+ +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from .basic import BasicVideoEncoder + +__all__ = ["TSPVideoEncoder"] + + +def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor: + return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1) + + +class TSPVideoEncoder(BasicVideoEncoder): + def __init__( + self, + parent: torch.nn.Module, + pool_sizes: List[Tuple[int, int, int]], + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + sep_tokens: Optional[str] = None, + ) -> None: + super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens) + self.pool_sizes = pool_sizes + self.sep_tokens = sep_tokens + + def _process_features( + self, + inputs: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + sep_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + nt, ns = inputs.shape[:2] + nl = int(ns**0.5) + outputs = [] + for pool_size in self.pool_sizes: + features = inputs.view(nt, nl, nl, -1) + for dim, p in enumerate(pool_size): + features = pool(features, p, dim=dim) + features = features.flatten(1, 2) + features = super()._process_features( + features, + start_token_embeds=start_token_embeds, + end_token_embeds=end_token_embeds, + ) + if sep_token_embeds is not None: + features = torch.cat([features, sep_token_embeds], dim=0) + outputs.append(features) + return torch.cat(outputs, dim=0) + + def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: + num_frames = [video.shape[0] for video in videos] + images = torch.cat(videos, dim=0) + features = self.parent.encode_images(images) + features = torch.split(features, num_frames) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + sep_token_embeds=self.embed_tokens(self.sep_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/language_model/builder.py b/llava/model/language_model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6936800c3c06666fc9d3d02f7424d0eb6409ea39 --- /dev/null +++ b/llava/model/language_model/builder.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import math +import os +import os.path as osp +import warnings +from dataclasses import asdict +from typing import Tuple + +import torch +from huggingface_hub import file_exists, repo_exists +from huggingface_hub.utils import HFValidationError +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoTokenizer, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizer, +) + + +from llava.constants import MEDIA_TOKENS +from llava.model.utils import packing +from llava.utils.logging import logger +from llava.utils.tokenizer import infer_stop_tokens + + +def has_tokenizer(repo_id_or_path: str) -> bool: + # Check if the tokenizer is in a local directory + if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")): + return True + + # Check if the tokenizer is in a Hugging Face Hub repo + try: + return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json") + except HFValidationError: + return False + + +def context_length_extension(config): + orig_ctx_len = getattr(config, "max_position_embeddings", None) + model_max_length = getattr(config, "model_max_length", None) + if orig_ctx_len and model_max_length > orig_ctx_len: + print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}") + scaling_factor = float(math.ceil(model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + return config + + +def build_llm_and_tokenizer( + model_name_or_path: str, + config: PretrainedConfig, + attn_implementation=None, + model_max_length=None, + *args, + **kwargs, +) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: + # print(model_name_or_path) + llm_cfg = AutoConfig.from_pretrained(model_name_or_path) + llm_cfg._attn_implementation = attn_implementation + llm_cfg.model_max_length = model_max_length + if model_max_length is not None: + context_length_extension(llm_cfg) + + # Quantization related + quantization_restore_from_checkpoint = False + if kwargs.get("quantize_model_class") is not None: + assert kwargs.get("model_args") is not None + quantize_model_class = kwargs.pop("quantize_model_class", None) + model_args = kwargs.pop("model_args", None) + + if quantize_model_class == "QLlamaForCausalLM": # TODO: Also change the name of this class + from .qllama import QLlamaConfig + + llm_cfg.architectures = "QLlamaForCausalLM" + _attn_implementation = llm_cfg._attn_implementation + llm_cfg = QLlamaConfig(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + elif quantize_model_class == "QMemLlamaForCausalLM": # TODO: Also change the name of this class + from .qmemllama import QMemLlamaConfig + + llm_cfg.architectures = "QMemLlamaForCausalLM" + llm_cfg = QMemLlamaConfig(**llm_cfg.to_dict()) + elif quantize_model_class == "FP8LinearQwen2ForCausalLM": + from .configuration_quantize import QuantizationConfig + from .fp8linearqwen2 import FP8LinearQwen2Config + + llm_cfg.architectures = "FP8LinearQwen2ForCausalLM" + coat_fp8_args = QuantizationConfig(**asdict(model_args)) + + # Remove the quantization args from llm_cfg and make it a independent config + model_args_dict = asdict(model_args) + for key in asdict(coat_fp8_args).keys(): + model_args_dict.pop(key, None) + + llm_cfg.coat_fp8_args = asdict(coat_fp8_args) + _attn_implementation = llm_cfg._attn_implementation + + llm_cfg = FP8LinearQwen2Config(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + + elif quantize_model_class == 
"FP8ActivationQwen2ForCausalLM": + from ..coat.activation.models._fp8_quantization_config import QuantizationConfig + from .fp8activationqwen2 import FP8ActivationQwen2Config + + quantization_restore_from_checkpoint = True + + llm_cfg.architectures = "FP8ActivationQwen2ForCausalLM" + coat_fp8_args = QuantizationConfig(**asdict(model_args)) + + # Remove the quantization args from llm_cfg and make it a independent config + model_args_dict = asdict(model_args) + for key in asdict(coat_fp8_args).keys(): + model_args_dict.pop(key, None) + + llm_cfg.coat_fp8_args = asdict(coat_fp8_args) + _attn_implementation = llm_cfg._attn_implementation + + llm_cfg = FP8ActivationQwen2Config(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + + elif quantize_model_class == "FP8ActivationResidualQwen2ForCausalLM": + from ..coat.activation.models._fp8_quantization_config import QuantizationConfig + from .fp8activationresidualqwen2 import FP8ActivationResidualQwen2Config + + quantization_restore_from_checkpoint = True + + llm_cfg.architectures = "FP8ActivationResidualQwen2ForCausalLM" + coat_fp8_args = QuantizationConfig(**asdict(model_args)) + + # Remove the quantization args from llm_cfg and make it a independent config + model_args_dict = asdict(model_args) + for key in asdict(coat_fp8_args).keys(): + model_args_dict.pop(key, None) + + llm_cfg.coat_fp8_args = asdict(coat_fp8_args) + _attn_implementation = llm_cfg._attn_implementation + + llm_cfg = FP8ActivationResidualQwen2Config(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + else: + raise ValueError(f"{quantize_model_class} is not supported quantize_model_class.") + + kwargs.pop("quantize_model_class", None) + + if quantize_model_class in [ + "FP8LinearQwen2ForCausalLM", + "FP8ActivationQwen2ForCausalLM", + "FP8ActivationResidualQwen2ForCausalLM", + ]: # Remove the quantization args from llm_cfg and make it a independent config + llm_cfg.update(model_args_dict) + else: + llm_cfg.update(asdict(model_args)) + # print(model_args) + + if quantization_restore_from_checkpoint: + fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None) + + llm = AutoModelForCausalLM.from_pretrained( + model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs + ) + + else: + llm = AutoModelForCausalLM.from_pretrained( + model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs + ) + packing.patch(llm) + + # Locate the tokenizer. + llm_path = model_name_or_path + if not has_tokenizer(llm_path): + llm_path = osp.join(llm_path, "llm") + if not has_tokenizer(llm_path): + raise ValueError(f"Cannot find tokenizer in {llm_path}.") + + tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False) + if model_max_length is not None: + tokenizer.model_max_length = model_max_length + + # Load chat template if specified. 
+ if getattr(config, "chat_template", None) is not None: + logger.info(f"Using chat template: {config.chat_template}") + fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja") + with open(fpath) as fd: + chat_template = fd.read() + tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "") + + # Set stop tokens for the tokenizer + tokenizer.stop_tokens = infer_stop_tokens(tokenizer) + tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens) + + # Add media tokens to the tokenizer + tokenizer.media_tokens = MEDIA_TOKENS + tokenizer.media_token_ids = {} + for name, token in MEDIA_TOKENS.items(): + tokenizer.add_tokens([token], special_tokens=True) + tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token) + + # TODO(ligeng): is this necessary for llava? + config.hidden_size = llm.config.hidden_size + return llm, tokenizer diff --git a/llava/model/language_model/chat_templates/mistral.jinja b/llava/model/language_model/chat_templates/mistral.jinja new file mode 100644 index 0000000000000000000000000000000000000000..774073e68db8dc9ffc65daab8d948c0079235a5d --- /dev/null +++ b/llava/model/language_model/chat_templates/mistral.jinja @@ -0,0 +1,11 @@ +{{ bos_token }} + +{% for message in messages if message['content'] is not none %} + {% if message['role'] == 'system' %} + {{ message['content'] | trim + '\n\n' }} + {% elif message['role'] == 'user' %} + {{ '[INST] ' + message['content'] | trim + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ ' ' + message['content'] | trim + eos_token }} + {% endif %} +{% endfor %} diff --git a/llava/model/language_model/chat_templates/qwen2.jinja b/llava/model/language_model/chat_templates/qwen2.jinja new file mode 100644 index 0000000000000000000000000000000000000000..d3b0fd8c09e426c3a8c6e29edde3e56586c961bb --- /dev/null +++ b/llava/model/language_model/chat_templates/qwen2.jinja @@ -0,0 +1,11 @@ +{% if messages[0]['role'] != 'system' %} + {{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }} +{% endif %} + +{% for message in messages if message['content'] is not none %} + {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }} +{% endfor %} + +{% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} +{% endif %} diff --git a/llava/model/language_model/configuration_quantize.py b/llava/model/language_model/configuration_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..9b226923e1ab06e6f868455e72dd8430386c8ac8 --- /dev/null +++ b/llava/model/language_model/configuration_quantize.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
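The qwen2.jinja template above is what `build_llm_and_tokenizer` loads and, after stripping the template file's formatting whitespace, assigns to `tokenizer.chat_template`. A minimal sketch of the prompt it produces, rendered directly with the `jinja2` package on a whitespace-compacted copy of the template; in practice roughly the same string comes out of `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)`:

from jinja2 import Template

# Whitespace-compacted copy of qwen2.jinja, written inline so the sketch is standalone.
qwen2_template = (
    "{% if messages[0]['role'] != 'system' %}"
    "{{ '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}"
    "{% endif %}"
    "{% for message in messages if message['content'] is not none %}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

messages = [{"role": "user", "content": "Describe this audio clip."}]
print(Template(qwen2_template).render(messages=messages, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Describe this audio clip.<|im_end|>
# <|im_start|>assistant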
+ +from dataclasses import dataclass + +from transformers import PretrainedConfig + + +@dataclass +class QuantizationConfig: + quantize_model: str = "false" + symm: bool = True + epsilon: float = 1e-10 + fabit: str = "E4M3" + fwbit: str = "E4M3" + bobit: str = "E5M2" + row_blocksize: int = -1 + col_blocksize: int = -1 + qchoice: str = "none" + pad_to_multiple_of: int = 0 + + def __init__( + self, + quantize_model, + symm, + epsilon, + fabit, + fwbit, + bobit, + row_blocksize, + col_blocksize, + qchoice, + pad_to_multiple_of, + **kwargs, + ): + super().__init__() + self.quantize_model = quantize_model + self.symm = symm + self.epsilon = epsilon + self.fabit = fabit + self.fwbit = fwbit + self.bobit = bobit + self.row_blocksize = row_blocksize + self.col_blocksize = col_blocksize + self.qchoice = qchoice + self.pad_to_multiple_of = pad_to_multiple_of + + +# class QuantizationConfig(PretrainedConfig): +# def __init__( +# self, +# quantize_model="false", +# symm=True, +# epsilon=1e-10, +# fabit="E4M3", +# fwbit="E4M3", +# bobit="E5M2", +# row_blocksize=-1, +# col_blocksize=-1, +# qchoice="none", +# pad_to_multiple_of=0, +# **kwargs, +# ): +# super().__init__() +# self.quantize_model = quantize_model +# self.symm = symm +# self.epsilon = epsilon +# self.fabit = fabit +# self.fwbit = fwbit +# self.bobit = bobit +# self.row_blocksize = row_blocksize +# self.col_blocksize = col_blocksize +# self.qchoice = qchoice +# self.pad_to_multiple_of = pad_to_multiple_of diff --git a/llava/model/language_model/fp8_qwen2_convert_from_hf.py b/llava/model/language_model/fp8_qwen2_convert_from_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..71a7f9f602491851c21cf7a7702a7174b414bfee --- /dev/null +++ b/llava/model/language_model/fp8_qwen2_convert_from_hf.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import argparse +import os +from dataclasses import asdict, dataclass, field +from typing import Optional + +import torch +import transformers +from transformers import AutoConfig, AutoModelForCausalLM + +from ..coat.activation.models._fp8_quantization_config import QuantizationConfig +from .fp8activationqwen2 import FP8ActivationQwen2Config, make_state_dict_compatible + + +@dataclass +class ConvertArguments: + model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"}) + save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"}) + cache_dir: str = field(default=None, metadata={"help": "Directory to cache the model"}) + + +def download_and_convert_qwen2(convert_args: ConvertArguments, quantization_args: QuantizationConfig): + """ + Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`, + and saves the converted model. + + Args: + model_name (str): The model name or path to download the LLaMA model. + save_path (str): The path where the converted model weights will be saved. + cache_dir (Optional[str]): Directory to cache the model. Defaults to None. 
+ + Returns: + None + """ + model_name = convert_args.model_name + save_path = convert_args.save_path + cache_dir = convert_args.cache_dir + + # Step 1: Download the original LLaMA model + model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 2: Initialize the model configuration for FP8 or other custom config + config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 3: Apply make_state_dict_compatible to convert weights + compatible_state_dict = make_state_dict_compatible(model.state_dict()) + + # Step 4: Create a new model instance with compatible configuration + fp8_config = FP8ActivationQwen2Config(**config.to_dict()) + fp8_config.coat_fp8_args = asdict(quantization_args) + fp8_config._name_or_path = save_path + + converted_model = AutoModelForCausalLM.from_config(fp8_config, torch_dtype=torch.bfloat16) + converted_model.load_state_dict(compatible_state_dict) + + # Step 5: Save the converted model and configuration using save_pretrained + os.makedirs(save_path, exist_ok=True) + converted_model.save_pretrained(save_path) + print(f"Converted model saved at {save_path}") + + +if __name__ == "__main__": + # Parse command-line arguments + parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8 + convert_args, quantization_args = parser.parse_args_into_dataclasses() + + # Call the function with parsed arguments + download_and_convert_qwen2(convert_args, quantization_args) diff --git a/llava/model/language_model/fp8activationqwen2.py b/llava/model/language_model/fp8activationqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..34bde452da53042dc55b39ae204e541e1eb1f920 --- /dev/null +++ b/llava/model/language_model/fp8activationqwen2.py @@ -0,0 +1,1722 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
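The fp8_qwen2_convert_from_hf script above is driven by `HfArgumentParser`, so the `ConvertArguments` fields become `--model_name`, `--save_path`, and `--cache_dir` command-line flags (plus whatever flags the COAT `QuantizationConfig` dataclass defines, which is not shown here). A hedged sketch of invoking the converter programmatically instead; the checkpoint name is illustrative and `QuantizationConfig()` is assumed to be constructible with defaults:

from llava.model.coat.activation.models._fp8_quantization_config import QuantizationConfig
from llava.model.language_model.fp8_qwen2_convert_from_hf import (
    ConvertArguments,
    download_and_convert_qwen2,
)

convert_args = ConvertArguments(
    model_name="Qwen/Qwen2-7B-Instruct",  # any HF Qwen2 checkpoint (illustrative)
    save_path="./qwen2-7b-fp8",           # where the converted weights are written
)
quant_args = QuantizationConfig()          # assumption: the dataclass has usable defaults

download_and_convert_qwen2(convert_args, quant_args)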
+"""PyTorch Qwen2 model.""" + +import math +import os +from dataclasses import asdict, dataclass, field +from fnmatch import fnmatch +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2FlashAttention2, + Qwen2ForCausalLM, + Qwen2MLP, + Qwen2Model, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + Qwen2SdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +# FP8 related +from ..coat.activation.models._fp8_quantization_config import QuantizationConfig +from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule +from ..coat.activation.models._fp8manager import FP8Manager +from ..coat.activation.real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) +from ..liger.cross_entropy import LigerForCausalLMLoss +from ..qlinear_te import QLinearTE + +# from .configuration_quantize import QuantizationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class FP8ActivationQwen2Config(Qwen2Config): + model_type = "fp8activation_qwen2" + + def __init__( + self, + coat_fp8_args=None, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + tie_word_embeddings, + rope_theta, + rope_scaling, + use_sliding_window, + sliding_window, + max_window_layers, + attention_dropout, + **kwargs, + ) + + self.coat_fp8_args = coat_fp8_args 
+ + +class FP8ActivationQwen2BeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None): + super().__init__(config, qargs, layer_idx) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + if FP8Manager.is_first_microbatch: + # Directly use the corresponding weight + weight1, _, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, _, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, _, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + else: + weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch) + + weight1, weight2, weight3 = None, None, None + return _FP8ActivationQwen2BeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + weight1, + None, + weight1_s, + self.q_proj.bias, + self.k_proj.weight, + weight2, + None, + weight2_s, + self.k_proj.bias, + self.v_proj.weight, + weight3, + None, + weight3_s, + self.v_proj.bias, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + return _FP8ActivationQwen2BeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + weight1, + weight1_t, + weight1_s, + 
self.k_proj.weight, + weight2, + weight2_t, + weight2_s, + self.v_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("This should be implemented in the future") + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _FP8ActivationQwen2BeforeAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight1_bias, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight2_bias, + weight3_origin, + weight3, + weight3_t, + weight3_s, + weight3_bias, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert ( + weight1_t is None and weight2_t is None and weight3_t is None + ) # we should not pass W^T to here if weight memory efficient + if weight1 is None: + assert weight2 is None and weight3 is None + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states + fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states + fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s + else: + ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s + ctx.bias = weight1_bias, weight2_bias, weight3_bias + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return re_x, fc1_x, fc2_x, fc3_x + + @staticmethod + def backward(ctx, fp_grad, query_g, key_g, value_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight + weight1_bias, weight2_bias, weight3_bias = ctx.bias + + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Gradient of Bias TODO: make this better + if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None: + att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0) + att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0) + att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0) + else: + att_q_bg = None + att_k_bg = None + att_v_bg = None + + # Quantize the RoPE and FlashAttention Output. 
grad_input and grad_weight requires different data layout. + query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose( + query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose( + key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose( + value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + + fc1_g1, att_q_wg = fp8_linear_backward( + ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size + ) + fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size) + fc1_g3, att_v_wg = fp8_linear_backward( + ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size + ) + + fc1_g = fc1_g1 + fc1_g2 + fc1_g3 + + # LayerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together, and prepare the input of the next layer. + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd. 
+ return ( + re_g, + in_g, + in_sg_g16, + att_q_wg, + None, + None, + None, + att_q_bg, + att_k_wg, + None, + None, + None, + att_k_bg, + att_v_wg, + None, + None, + None, + att_v_bg, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationQwen2AfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward(self, re_x, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + if FP8Manager.is_first_microbatch: + weight4, _, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + else: + weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch) + weight1 = None + + return _FP8ActivationQwen2AfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + weight4, + None, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight4, weight4_t, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationQwen2AfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + weight4, + weight4_t, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _FP8ActivationQwen2AfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs + ): + time_bench = os.getenv("TIME_BENCH") + + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight4_t is None + if weight4 is None: # the second batch + weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) # + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + # import IPython + # IPython.embed() + # Add the activations together + 
fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size) + + if time_bench: + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if int(os.environ.get("LOCAL_RANK")) == 0: + print( + f"[AfterAt] Part 1: {start_1.elapsed_time(end_1):.6f} ms | " + f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | " + f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | " + f" Input shape: {re_x.shape}" + ) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight4_t is None + ctx.weight = weight4_origin, weight4_s + else: + ctx.weight = weight4_t, weight4_s + ctx.group_size = group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight4_t, weight4_s = ctx.weight + group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight4_t, weight4_s = fp8_division_transpose( + weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True + ) + fc4_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size + ) + + return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class FP8ActivationQwen2MLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.training = True + + # below is only used when training = False + assert config.hidden_act == "silu", "We only support silu activation currently" + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + if FP8Manager.is_first_microbatch: + # Directly use the corresponding 
weight + weight1, _, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, _, weight2_s = self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, _, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + else: + weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + weight1, weight2, weight3 = None, None, None + return _FP8ActivationQwen2MLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + weight1, + None, + weight1_s, + self.up_proj.weight, + weight2, + None, + weight2_s, + self.down_proj.weight, + weight3, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationQwen2MLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + weight1, + weight1_t, + weight1_s, + self.up_proj.weight, + weight2, + weight2_t, + weight2_s, + self.down_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("Need TODO") + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _FP8ActivationQwen2MLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. 
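+ # In the weight_memory_efficient path only the quantization scales reach this kernel, so the FP8
+ # copies of the gate/up/down weights are rebuilt here from the master (origin) weights with
+ # fp8_division, using the cached scales, before the two projection GEMMs below.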
+ if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None # memory efficient + if weight1 is None: + assert weight2 is None and weight3 is None + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj + up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight3 is None: # memory efficient + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size) + + # Add the activation together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # Element-wise Multiplication gradient checkpointing + # silu gradient checkpointing + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + fc3_g, weight3_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight3_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward( + silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Silu activation + silu_g, silu_gs, silu_g_t = fp8_silu_backward( + gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + + # Gate Proj + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size + ) + fc2_g, weight2_grad = fp8_linear_backward( + ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size + ) + + fc_g = fc1_g + fc2_g + + # layerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + weight3_grad, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationQwen2AttentionWithoutLinear(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
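+ In this FP8/COAT variant the q/k/v and output projections live in the surrounding
+ BeforeAttention / AfterAttention residual modules, so this class only applies RoPE, the KV-cache
+ update and the attention computation to already-projected query/key/value states.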
+ """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationQwen2FlashAttention2WithoutLinear(FP8ActivationQwen2AttentionWithoutLinear): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationQwen2SdpaAttentionWithoutLinear(FP8ActivationQwen2AttentionWithoutLinear): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + return attn_output, None, past_key_value + + +FP8LINEARQWEN2_ATTENTION_CLASSES = { + "eager": FP8ActivationQwen2AttentionWithoutLinear, + "flash_attention_2": FP8ActivationQwen2FlashAttention2WithoutLinear, + "sdpa": FP8ActivationQwen2SdpaAttentionWithoutLinear, +} + + +class FP8ActivationQwen2DecoderLayer(nn.Module): + def __init__(self, config: FP8ActivationQwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.BeforeAttention = FP8ActivationQwen2BeforeAttentionResidual(config, self.qargs, layer_idx) + self.AfterAttention = FP8ActivationQwen2AfterAttentionResidual(config, self.qargs, layer_idx) + self.MLPResidual = FP8ActivationQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + quant_hidden_states: torch.Tensor, + scale_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
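+ quant_hidden_states (`torch.Tensor`): FP8 (E4M3) copy of `hidden_states` produced by
+ `Coat_quantize_bgn` or by the previous decoder layer.
+ scale_hidden_states (`torch.Tensor`): per-group quantization scales that accompany
+ `quant_hidden_states`.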
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + time_bench = os.getenv("TIME_BENCH") + + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer + residual, query_states, key_states, value_states = self.BeforeAttention( + hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight + ) + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + # Self Attention without any linear layer + hidden_states, self_attn_weights, present_key_value = self.self_attn( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + # Coat: The Output Projection Linear Layer and Residual + hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states) + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + # Residual Connection, LayerNorm, and the whole MLP module + hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual( + hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight + ) + + if time_bench: + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if int(os.environ.get("LOCAL_RANK")) == 0: + print( + f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | " + f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | " + f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | " + f" Input shape: {hidden_states.shape}" + ) + + outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),) + + # if int(os.environ.get("LOCAL_RANK")) == 0: + # import IPython + # IPython.embed() + # else: + # import time + # time.sleep(1000) + # if output_attentions: + # outputs += (self_attn_weights,) + + # if use_cache: + # outputs += (present_key_value,) + + return outputs + + +class FP8ActivationQwen2PreTrainedModel(Qwen2PreTrainedModel): + config_class = FP8ActivationQwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FP8ActivationQwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class FP8ActivationQwen2Model(FP8ActivationQwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: FP8ActivationQwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [FP8ActivationQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Quantize + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.quantize_input_before_block = Coat_quantize_bgn(self.qargs) + self.quantize_output_after_block = Coat_quantize_end(self.qargs) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + self.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + # import warnings + + # # ๅฐ†ๆ‰€ๆœ‰ UserWarning ๆๅ‡ไธบๅผ‚ๅธธ + # warnings.simplefilter("error", UserWarning) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Prepare the input for Coat decoderlayer + hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states) + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # if int(os.environ.get("LOCAL_RANK")) == 0: + # print(f"FP8 Layer Idx: {layer_idx}, Input Shape: {hidden_states.shape}, Memory Allocated: {torch.cuda.memory_allocated() // 1024 ** 2} MB | Peak: {torch.cuda.max_memory_allocated() // 1024 ** 2} MB") + + # NOTE: We explicitly force the LLM do not use Gradient Checkpointing + # exit(0) + # if self.gradient_checkpointing and self.training: + # layer_outputs = self._gradient_checkpointing_func( + # decoder_layer.__call__, + # hidden_states, + # quant_hidden_states, + # scale_hidden_states, + # causal_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, + # cache_position, + # position_embeddings, + # ) + # else: + layer_outputs = decoder_layer( + hidden_states, + quant_hidden_states, + scale_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Summarize the output of the Decoder Layer + hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + _update_causal_mask = Qwen2Model._update_causal_mask + + +class FP8ActivationQwen2ForCausalLM(FP8ActivationQwen2PreTrainedModel): + _tied_weights_keys = 
["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = FP8ActivationQwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @property + @lru_cache + def loss_function(self): + return LigerForCausalLMLoss + + forward = Qwen2ForCausalLM.forward + + +AutoConfig.register("fp8activation_qwen2", FP8ActivationQwen2Config) +AutoModel.register(FP8ActivationQwen2Config, FP8ActivationQwen2Model) +AutoModelForCausalLM.register(FP8ActivationQwen2Config, FP8ActivationQwen2ForCausalLM) + + +def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]): + compatible_state_dict = {} + + for key, value in state_dict.items(): + if fnmatch(key, "*self_attn.q_proj*"): + new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj") + elif fnmatch(key, "*self_attn.k_proj*"): + new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj") + elif fnmatch(key, "*self_attn.v_proj*"): + new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj") + elif fnmatch(key, "*self_attn.o_proj*"): + new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj") + + elif fnmatch(key, "*mlp.gate_proj*"): + new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj") + elif fnmatch(key, "*mlp.up_proj*"): + new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj") + elif fnmatch(key, "*mlp.down_proj*"): + new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj") + + else: + new_key = key + + compatible_state_dict[new_key] = value + + return compatible_state_dict diff --git a/llava/model/language_model/fp8activationresidualqwen2.py b/llava/model/language_model/fp8activationresidualqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..9bccf84346e0e2b59ed6dc56ca1a6799bc11c5c5 --- /dev/null +++ b/llava/model/language_model/fp8activationresidualqwen2.py @@ -0,0 +1,1586 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 model.""" + +import math +import os +from dataclasses import asdict, dataclass, field +from fnmatch import fnmatch +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2FlashAttention2, + Qwen2ForCausalLM, + Qwen2MLP, + Qwen2Model, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + Qwen2SdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +# FP8 related +from ..coat.activation.models._fp8_quantization_config import QuantizationConfig +from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule +from ..coat.activation.models._fp8manager import FP8Manager +from ..coat.activation.real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) +from ..qlinear_te import QLinearTE +from .configuration_quantize import QuantizationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class FP8ActivationResidualQwen2Config(Qwen2Config): + model_type = "fp8activationresidual_qwen2" + + def __init__( + self, + coat_fp8_args=None, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + tie_word_embeddings, + rope_theta, + rope_scaling, + use_sliding_window, + sliding_window, + max_window_layers, + attention_dropout, + **kwargs, + ) + + 
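+ # coat_fp8_args is stored verbatim and later unpacked into a QuantizationConfig
+ # (QuantizationConfig(**config.coat_fp8_args)) by the decoder stack. A minimal sketch with
+ # assumed field values, mirroring how the fields are read below (fabit/fwbit/fobit for the
+ # forward pass, babit/bwbit/bobit for backward, plus group_size and weight_memory_efficient):
+ #
+ #     config = FP8ActivationResidualQwen2Config(
+ #         coat_fp8_args=dict(
+ #             fabit="E4M3", fwbit="E4M3", fobit="E4M3",
+ #             babit="E5M2", bwbit="E5M2", bobit="E5M2",
+ #             group_size=16, weight_memory_efficient=True,
+ #         ),
+ #     )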
self.coat_fp8_args = coat_fp8_args + + +class FP8ActivationResidualQwen2BeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__( + self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None + ): + super().__init__(config, qargs, layer_idx) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + + def forward(self, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch) + return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply( + x, + s, + self.q_proj.weight, + None, + None, + weight1_s, + self.q_proj.bias, + self.k_proj.weight, + None, + None, + weight2_s, + self.k_proj.bias, + self.v_proj.weight, + None, + None, + weight3_s, + self.v_proj.bias, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply( + x, + s, + self.q_proj.weight, + weight1, + weight1_t, + weight1_s, + self.k_proj.weight, + weight2, + weight2_t, + weight2_s, + self.v_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("This should be implemented in the future") + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _FP8ActivationResidualQwen2BeforeAttentionResidual(torch.autograd.Function): + 
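+ """
+ Fused autograd kernel for the pre-attention block of the residual variant: the forward pass
+ consumes the already-quantized residual stream (in_x, in_s), applies the FP8 RMSNorm and the
+ three FP8 Q/K/V projection GEMMs; the backward pass re-quantizes the incoming Q/K/V gradients,
+ runs the corresponding input- and weight-gradient GEMMs, backpropagates through the RMSNorm and
+ folds the result back into the quantized residual gradient.
+ """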
@staticmethod + def forward( + ctx, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight1_bias, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight2_bias, + weight3_origin, + weight3, + weight3_t, + weight3_s, + weight3_bias, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states + fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states + fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s + else: + ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s + ctx.bias = weight1_bias, weight2_bias, weight3_bias + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return in_x, in_s, fc1_x, fc2_x, fc3_x + + @staticmethod + def backward(ctx, q_grad, s_grad, query_g, key_g, value_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight + weight1_bias, weight2_bias, weight3_bias = ctx.bias + + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Gradient of Bias TODO: make this better + if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None: + att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0) + att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0) + att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0) + else: + att_q_bg = None + att_k_bg = None + att_v_bg = None + + # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout. 
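+ # Each gradient is quantized once per tensor and returned in two layouts: the row-major copy
+ # (e.g. query_g) is consumed by the input-gradient GEMM against the transposed weight, while the
+ # transposed copy (query_g_t) is consumed by the weight-gradient GEMM against the saved
+ # transposed layer input ln_x_t inside fp8_linear_backward below.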
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose( + query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose( + key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose( + value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + + fc1_g1, att_q_wg = fp8_linear_backward( + ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size + ) + fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size) + fc1_g3, att_v_wg = fp8_linear_backward( + ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size + ) + + fc1_g = fc1_g1 + fc1_g2 + fc1_g3 + + # LayerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together, and prepare the input of the next layer. + in_g, in_sg, in_sg_g16 = fp8_add_Ig16_Ifp_Opt( + q_grad, s_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd. 
+ return ( + in_g, + in_sg_g16, + att_q_wg, + None, + None, + None, + att_q_bg, + att_k_wg, + None, + None, + None, + att_k_bg, + att_v_wg, + None, + None, + None, + att_v_bg, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationResidualQwen2AfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward(self, re_qx, re_sx, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch) + + return _FP8ActivationResidualQwen2AfterAttentionResidual.apply( + re_qx, + re_sx, + in_x, + self.o_proj.weight, + None, + None, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight4, weight4_t, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationResidualQwen2AfterAttentionResidual.apply( + re_qx, + re_sx, + in_x, + self.o_proj.weight, + weight4, + weight4_t, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _FP8ActivationResidualQwen2AfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_qx, + re_sx, + flash_x, + weight4_origin, + weight4, + weight4_t, + weight4_s, + group_size, + fwobits, + layer_id, + config, + qargs, + ): + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight4 is None # memory efficient + weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s) + fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) # + + # import IPython + # IPython.embed() + # Add the activations together + fp_x, (out_x, out_s) = fp8_add_Ig16_Ifp_Ofp_Og16(re_qx, re_sx, fc4_x, flash_qx.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight4_t is None + ctx.weight = weight4_origin, weight4_s + else: + ctx.weight = weight4_t, weight4_s + ctx.group_size = group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight4_t, weight4_s = ctx.weight + 
group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight4_t, weight4_s = fp8_division_transpose( + weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True + ) + fc4_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size + ) + + return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class FP8ActivationResidualQwen2MLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.training = True + + # below is only used when training = False + assert config.hidden_act == "silu", "We only support silu activation currently" + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch) + + return _FP8ActivationResidualQwen2MLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + None, + None, + weight1_s, + self.up_proj.weight, + None, + None, + weight2_s, + self.down_proj.weight, + None, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationResidualQwen2MLPResidual.apply( + re_x, + x, + s, + 
self.gate_proj.weight, + weight1, + weight1_t, + weight1_s, + self.up_proj.weight, + weight2, + weight2_t, + weight2_s, + self.down_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("Need TODO") + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _FP8ActivationResidualQwen2MLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. + if qargs.weight_memory_efficient: + assert weight1 is None and weight2 is None and weight3 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj + up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight3 is None: # memory efficient + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size) + + # Add the activation together + out_x, out_s = fp8_add_Ifp_Ifp_Og16(re_x, fc3_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
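+            # Caching the already-transposed FP8 weights avoids the fp8_division_transpose calls in
+            # backward, at the cost of holding an extra copy of every projection weight.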
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return out_x, out_s + + @staticmethod + def backward(ctx, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + fc3_g, weight3_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight3_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward( + silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Silu activation + silu_g, silu_gs, silu_g_t = fp8_silu_backward( + gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + + # Gate Proj + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size + ) + fc2_g, weight2_grad = fp8_linear_backward( + ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size + ) + + fc_g = fc1_g + fc2_g + + # layerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + out_g, out_gs_max, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + weight3_grad, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationResidualQwen2AttentionWithoutLinear(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationResidualQwen2FlashAttention2WithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+    ):
+        bsz, q_len, _ = query_states.size()
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # Activate slicing cache only if the config has a value `sliding_window` attribute
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            kv_seq_len = key_states.shape[-2] + cache_position[0]
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
+                slicing_tokens = 1 - self.config.sliding_window
+
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
+
+                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+                if past_key.shape[-2] != self.config.sliding_window - 1:
+                    raise ValueError(
+                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                        f" {past_key.shape}"
+                    )
+
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, slicing_tokens:]
+                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in float16 just to be sure everything works as expected.
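+        # Fallback order for the cast target: the autocast dtype, then the pre-quantization dtype
+        # recorded on the config.  Note that this linear-free attention variant defines no self.q_proj,
+        # so the final fallback (self.q_proj.weight.dtype) assumes the projections still live on this
+        # module and would otherwise raise AttributeError.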
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationResidualQwen2SdpaAttentionWithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
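+        # e.g. a prefill step (q_len > 1) with no explicit mask lets SDPA build the causal mask itself
+        # (is_causal=True); single-token decoding (q_len == 1) or an explicit causal_mask disables it.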
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + return attn_output, None, past_key_value + + +FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES = { + "eager": FP8ActivationResidualQwen2AttentionWithoutLinear, + "flash_attention_2": FP8ActivationResidualQwen2FlashAttention2WithoutLinear, + "sdpa": FP8ActivationResidualQwen2SdpaAttentionWithoutLinear, +} + + +class FP8ActivationResidualQwen2DecoderLayer(nn.Module): + def __init__(self, config: FP8ActivationResidualQwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.BeforeAttention = FP8ActivationResidualQwen2BeforeAttentionResidual(config, self.qargs, layer_idx) + self.AfterAttention = FP8ActivationResidualQwen2AfterAttentionResidual(config, self.qargs, layer_idx) + self.MLPResidual = FP8ActivationResidualQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + quant_hidden_states: torch.Tensor, + scale_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer + residual_quant, residual_scale, query_states, key_states, value_states = self.BeforeAttention( + quant_hidden_states, scale_hidden_states, self.input_layernorm.weight + ) + + # Self Attention without any linear layer + hidden_states, self_attn_weights, present_key_value = self.self_attn( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # Coat: The Output Projection Linear Layer and Residual + hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention( + residual_quant, residual_scale, hidden_states + ) + + # Residual Connection, LayerNorm, and the whole MLP module + quant_hidden_states, scale_hidden_states = self.MLPResidual( + hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight + ) + + outputs = ((quant_hidden_states, scale_hidden_states),) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class FP8ActivationResidualQwen2PreTrainedModel(Qwen2PreTrainedModel): + config_class = FP8ActivationResidualQwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FP8ActivationResidualQwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class FP8ActivationResidualQwen2Model(FP8ActivationResidualQwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: FP8ActivationResidualQwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [FP8ActivationResidualQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Quantize + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.quantize_input_before_block = Coat_quantize_bgn(self.qargs) + self.quantize_output_after_block = Coat_quantize_end(self.qargs) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Prepare the input for Coat decoderlayer + quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states) + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + quant_hidden_states, + scale_hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + quant_hidden_states, + scale_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Summarize the output of the Decoder Layer + hidden_states = self.quantize_output_after_block(quant_hidden_states, scale_hidden_states) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + _update_causal_mask = Qwen2Model._update_causal_mask + + +class FP8ActivationResidualQwen2ForCausalLM(FP8ActivationResidualQwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = FP8ActivationResidualQwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def 
get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = Qwen2ForCausalLM.forward + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation + prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation + + +AutoConfig.register("fp8activationresidual_qwen2", FP8ActivationResidualQwen2Config) +AutoModel.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2Model) +AutoModelForCausalLM.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2ForCausalLM) + + +def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]): + compatible_state_dict = {} + + for key, value in state_dict.items(): + if fnmatch(key, "*self_attn.q_proj*"): + new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj") + elif fnmatch(key, "*self_attn.k_proj*"): + new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj") + elif fnmatch(key, "*self_attn.v_proj*"): + new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj") + elif fnmatch(key, "*self_attn.o_proj*"): + new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj") + + elif fnmatch(key, "*mlp.gate_proj*"): + new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj") + elif fnmatch(key, "*mlp.up_proj*"): + new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj") + elif fnmatch(key, "*mlp.down_proj*"): + new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj") + + else: + new_key = key + + compatible_state_dict[new_key] = value + + return compatible_state_dict diff --git a/llava/model/language_model/fp8linearqwen2.py b/llava/model/language_model/fp8linearqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae93bb8d6798e9f327198abc405866e4a1527e2 --- /dev/null +++ b/llava/model/language_model/fp8linearqwen2.py @@ -0,0 +1,356 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
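+# This module mirrors the Hugging Face Qwen2 implementation but swaps every nn.Linear in the
+# attention and MLP blocks for QLinearTE, an FP8 linear layer configured through the
+# QuantizationConfig passed as `coat_fp8_args`; the attention and MLP math itself is reused from
+# the upstream Qwen2 classes.
+#
+# Minimal usage sketch (hypothetical values; assumes this module has been imported so that the
+# AutoConfig / AutoModelForCausalLM registrations at the bottom of the file have run):
+#
+#     from transformers import AutoConfig, AutoModelForCausalLM
+#     cfg = AutoConfig.for_model("fp8linear_qwen2", coat_fp8_args={...})  # fill in the FP8 settings
+#     model = AutoModelForCausalLM.from_config(cfg)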
+"""PyTorch Qwen2 model.""" + +import math +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2FlashAttention2, + Qwen2ForCausalLM, + Qwen2MLP, + Qwen2Model, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + Qwen2SdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from ..liger.cross_entropy import LigerForCausalLMLoss +from ..qlinear_te import QLinearTE +from .configuration_quantize import QuantizationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class FP8LinearQwen2Config(Qwen2Config): + model_type = "fp8linear_qwen2" + + def __init__( + self, + coat_fp8_args=None, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + tie_word_embeddings, + rope_theta, + rope_scaling, + use_sliding_window, + sliding_window, + max_window_layers, + attention_dropout, + **kwargs, + ) + + self.coat_fp8_args = coat_fp8_args + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 +class FP8LinearQwen2MLP(Qwen2MLP): + def __init__(self, config, layer_idx): + super().__init__(config) + # self.gate_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.up_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.down_proj = te.Linear(self.intermediate_size, self.hidden_size, bias=False) + + self.gate_proj = QLinearTE( + self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx + ) + self.up_proj = QLinearTE( + self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx + ) + self.down_proj = QLinearTE( + self.intermediate_size, self.hidden_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx + ) + + def forward(self, hidden_state): + 
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class FP8LinearQwen2Attention(Qwen2Attention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: FP8LinearQwen2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.q_proj = QLinearTE( + self.hidden_size, + self.num_heads * self.head_dim, + bias=True, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + self.k_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=True, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + self.v_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=True, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + self.o_proj = QLinearTE( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + + forward = Qwen2Attention.forward + + +class FP8LinearQwen2FlashAttention2(FP8LinearQwen2Attention): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + forward = Qwen2FlashAttention2.forward + + +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2 +class FP8LinearQwen2SdpaAttention(FP8LinearQwen2Attention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + forward = Qwen2SdpaAttention.forward + + +FP8LINEARQWEN2_ATTENTION_CLASSES = { + "eager": FP8LinearQwen2Attention, + "flash_attention_2": FP8LinearQwen2FlashAttention2, + "sdpa": FP8LinearQwen2SdpaAttention, +} + + +class FP8LinearQwen2DecoderLayer(nn.Module): + def __init__(self, config: FP8LinearQwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = FP8LinearQwen2MLP(config, layer_idx) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + forward = Qwen2DecoderLayer.forward + + +class FP8LinearQwen2PreTrainedModel(Qwen2PreTrainedModel): + config_class = FP8LinearQwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FP8LinearQwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class FP8LinearQwen2Model(FP8LinearQwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: FP8LinearQwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [FP8LinearQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + forward = Qwen2Model.forward + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + _update_causal_mask = Qwen2Model._update_causal_mask + + +class FP8LinearQwen2ForCausalLM(FP8LinearQwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = FP8LinearQwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @property + @lru_cache + def loss_function(self): + return LigerForCausalLMLoss + + forward = Qwen2ForCausalLM.forward + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation + prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation + + +AutoConfig.register("fp8linear_qwen2", FP8LinearQwen2Config) +AutoModel.register(FP8LinearQwen2Config, 
FP8LinearQwen2Model) +AutoModelForCausalLM.register(FP8LinearQwen2Config, FP8LinearQwen2ForCausalLM) diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fbe3f845abe162edd4d891cca35a752b0966da --- /dev/null +++ b/llava/model/language_model/llava_llama.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +import os +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from llava.model.loss import soft_cross_entropy +from llava.model.utils.packing import set_seqlens_in_batch +from llava.train.sequence_parallel.globals import get_pg_manager +from llava.utils.logging import logger + +from ...train.utils import calculate_loss_weight +from ..configuration_llava import LlavaConfig +from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel + + +class LlavaLlamaConfig(LlavaConfig): + model_type = "llava_llama" + + +# FIXME we will follow the convention to add a new class for CausalLM in the future +class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel): + config_class = LlavaLlamaConfig + main_input_name = "input_embeds" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + + def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None: + super().__init__(config) + self.init_vlm(config=config, *args, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + if hasattr(cls, "load_pretrained"): + return cls.load_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + return super(LlavaLlamaModel).from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + 
token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + media: Optional[Dict[str, List[torch.Tensor]]] = None, + images: Optional[torch.FloatTensor] = None, + media_config: Optional[List] = None, + attention_mask: Optional[torch.Tensor] = None, + media_meta: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + packing: bool = True, + force_packing: bool = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + dpo_forward: bool = False, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + self.freezed_module_patch() + + if images is not None: + if media is not None: + raise ValueError("Both 'media' and 'images' are provided. Please provide only one.") + logger.warning("The 'images' argument is deprecated. Please use 'media' instead.") + media = {"image": images} + + if media_config is None: + media_config = defaultdict(dict) + if inputs_embeds is None: + inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask,media_meta) + + if force_packing or (packing and self.training and not dpo_forward): + if seqlens_in_batch is None: + seqlens_in_batch = torch.sum(attention_mask, dim=1) + set_seqlens_in_batch(seqlens_in_batch) + + (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data( + inputs_embeds, attention_mask, position_ids, labels + ) + + outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + labels=labels, + **kwargs, + ) + + if self.training and getattr(self.config, "time_token_ids", []): + outputs.loss = soft_cross_entropy( + outputs.logits, + labels, + soft_tokens=self.config.time_token_ids, + std=self.config.soft_ce_std, + ) + + # Loss rescale for SP + if get_pg_manager() is not None: + loss_weight = calculate_loss_weight(labels) + outputs.loss = outputs.loss * loss_weight + + if dpo_forward: + return outputs.logits, labels + + return outputs + + +AutoConfig.register("llava_llama", LlavaLlamaConfig) +AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel) diff --git a/llava/model/language_model/qllama.py b/llava/model/language_model/qllama.py new file mode 100644 index 0000000000000000000000000000000000000000..db823b94634ec5cfeb589319e947f64a473f3a30 --- /dev/null +++ b/llava/model/language_model/qllama.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import math +import os +import time +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from ..qlinear_te import QLinearTE + +try: + import transformer_engine.pytorch as te +except: + pass +from ..qfunction import * + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "QLlamaConfig" + + +class QLlamaConfig(LlamaConfig): + model_type = "qllama" + + +class QLlamaMLP(LlamaMLP): + def __init__(self, config, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = QLinearTE( + self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx + ) + self.up_proj = QLinearTE(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx) + self.down_proj = QLinearTE( + self.intermediate_size, self.hidden_size, bias=False, args=config, layer_idx=layer_idx + ) + + +class QLlamaAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: QLlamaConfig, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + # self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + # self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, 
bias=config.attention_bias) + # self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + self.q_proj = QLinearTE( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + self.k_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + self.v_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + self.o_proj = QLinearTE( + self.num_heads * self.head_dim, + self.hidden_size, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + + +class QLlamaFlashAttention2(QLlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + forward = LlamaFlashAttention2.forward + + +class QLlamaSdpaAttention(QLlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + forward = LlamaSdpaAttention.forward + + +QLLAMA_ATTENTION_CLASSES = { + "eager": QLlamaAttention, + "flash_attention_2": QLlamaFlashAttention2, + "sdpa": QLlamaSdpaAttention, +} + + +class QLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: QLlamaConfig, layer_idx): + super().__init__(config, layer_idx=layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = QLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = QLlamaMLP(config, layer_idx) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + forward = LlamaDecoderLayer.forward + + +class QLlamaPreTrainedModel(LlamaPreTrainedModel): + config_class = QLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["QLlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear) or isinstance(module, QLinearTE): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class QLlamaModel(QLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: QLlamaConfig + """ + + def __init__(self, config: QLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [QLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + _update_causal_mask = LlamaModel._update_causal_mask + forward = LlamaModel.forward + + +class QLlamaForCausalLM(QLlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = QLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.forward_step_id = 0 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +class QLlamaForSequenceClassification(QLlamaPreTrainedModel): + def __init__(self, config): + 
super().__init__(config) + self.num_labels = config.num_labels + self.model = QLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + forward = LlamaForSequenceClassification.forward + + +AutoConfig.register("qllama", QLlamaConfig) +AutoModel.register(QLlamaConfig, QLlamaModel) +AutoModelForCausalLM.register(QLlamaConfig, QLlamaForCausalLM) diff --git a/llava/model/language_model/qllava_qllama.py b/llava/model/language_model/qllava_qllama.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe4f0785352cc1d9fef4b20afdfb6183672fc32 --- /dev/null +++ b/llava/model/language_model/qllava_qllama.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +import inspect +import os +import os.path as osp +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from transformers import ( + AutoConfig, + AutoModel, + GenerationConfig, + LlamaConfig, + LlamaForCausalLM, + PretrainedConfig, + PreTrainedModel, +) +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import ContextManagers, no_init_weights + +from ..configuration_llava import LlavaConfig +from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel +from ..utils import get_model_config, get_model_config_fp8 +from .builder import build_llm_and_tokenizer +from .llava_llama import LlavaLlamaConfig, LlavaLlamaModel + +quantize_args_to_model_class = { + "fp8Linear_llama": "QLlamaForCausalLM", + "fp8LinearAndActivation_llama": "QMemLlamaForCausalLM", + "fp8Linear_qwen2": "FP8LinearQwen2ForCausalLM", + "fp8Activation_qwen2": "FP8ActivationQwen2ForCausalLM", + "fp8ActivationResidual_qwen2": "FP8ActivationResidualQwen2ForCausalLM", +} + + +class QLlavaLlamaConfig(LlavaLlamaConfig): + model_type = "qllava_qllama" + + +## FIXME we will follow the convention to add a new class for CausalLM in the future +class QLlavaLlamaModel(LlavaLlamaModel): + config_class = QLlavaLlamaConfig + main_input_name = "input_embeds" + supports_gradient_checkpointing = True + + def __init__(self, config: QLlavaLlamaConfig = None, model_args=None, *args, **kwargs) -> None: + PreTrainedModel.__init__(self, config) + return self.init_vlm(config=config, model_args=model_args, *args, **kwargs) + + # rewrite to support QLlama + def init_vlm(self, config: PreTrainedModel = None, model_args=None, *args, **kwargs): + # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation. 
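+ # Initialization flow:
+ # 1. Skip re-initialization if the LLM, vision tower, or mm_projector already exist.
+ # 2. Resolve model_dtype from the config, defaulting to torch.float16.
+ # 3. Unpack the sub-configs; the fp8-activation variants also return an fp8_llm_cfg,
+ #    which is forwarded to the LLM builder through kwargs.
+ # 4. Map model_args.quantize_model to its quantized CausalLM class, build the LLM and
+ #    tokenizer, and tag every LLM module with its layer_name.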
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"): + # already initialized, skipped + return + + model_dtype = getattr(config, "model_dtype", "torch.float16") + if not hasattr(config, "model_dtype"): + warnings.warn("model_dtype not found in config, defaulting to torch.float16.") + config.model_dtype = model_dtype + + if model_args.quantize_model in ["fp8Activation_qwen2", "fp8ActivationResidual_qwen2"]: + cfgs = get_model_config_fp8(config) # The first cfg is fp8 + else: + cfgs = get_model_config(config) + if len(cfgs) == 3: + llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs + elif len(cfgs) == 4: + llm_cfg, vision_tower_cfg, mm_projector_cfg, fp8_llm_cfg = cfgs + kwargs.update({"fp8_llm_cfg": fp8_llm_cfg}) + else: + raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.") + + kwargs.update( + { + "quantize_model_class": quantize_args_to_model_class[model_args.quantize_model], + "model_args": model_args, + } + ) + self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) + + + for name, module in self.llm.named_modules(): + module.layer_name = name + + self.pad_to_multiple_of = model_args.pad_to_multiple_of + + self.post_config() + self.is_loaded = True + + assert ( + self.llm is not None or self.vision_tower is not None or self.mm_projector is not None + ), "At least one of the components must be instantiated." + + +AutoConfig.register("qllava_qllama", QLlavaLlamaConfig) +AutoModel.register(QLlavaLlamaConfig, QLlavaLlamaModel) diff --git a/llava/model/language_model/qmemllama.py b/llava/model/language_model/qmemllama.py new file mode 100644 index 0000000000000000000000000000000000000000..26e794840df2230013a66f643745e9f50ffbeef9 --- /dev/null +++ b/llava/model/language_model/qmemllama.py @@ -0,0 +1,679 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch LLaMA model.""" +import math +import os +import time +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from ..qlinear_te import QLinearTE + +try: + import transformer_engine.pytorch as te +except: + pass + +from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "QMemLlamaConfig" + + +class QMemLlamaConfig(LlamaConfig): + model_type = "qmemllama" + + +class QLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.qargs = args + self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in") + self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out") + + def forward(self, hidden_states, s): + hidden_states = self.QAct_layernorm_in(hidden_states, s) + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + + hidden_states, s = self.QAct_layernorm_out(hidden_states) + return hidden_states, s + + +ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm) + + +class QMemLlamaMLP(LlamaMLP): + def __init__(self, config, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + self.gate_proj = QLinear( + self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate" + ) + self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up") + self.down_proj = QLinear( + self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down" + ) + self.act_fn = ACT2FN[config.hidden_act] + + self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum") + self.QAct_act_gate = 
QAct_FPin(config, layer_type="mlp_act_gate") + + self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up") + + self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in") + self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out") + + self.QMul_act = QMul(config, layer_type="mul_act") + + def forward(self, x, s): + if self.config.pretraining_tp > 1: + raise ValueError("Currently Quantization is not implemented for tensor parallel for simplicity") + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + x = self.QAct_act_sum(x, s) + x_gate, s_gate = self.QAct_act_gate(x) + x_up, s_up = self.QAct_act_up(x) + x_gate, s_gate = self.gate_proj(x_gate, s_gate) + x_gate = self.QAct_act_in(x_gate, s_gate) + x_gate = self.act_fn(x_gate) + x_gate, s_gate = self.QAct_act_out(x_gate) + + x_up, s_up = self.up_proj(x_up, s_up) + x, s = self.QMul_act(x_gate, x_up, s_gate, s_up) + down_proj, s = self.down_proj(x, s) + return down_proj, s + + +class QMemLlamaAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + self.q_proj = QLinear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_q", + ) + self.k_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_k", + ) + self.v_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_v", + ) + self.o_proj = QLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=config.attention_bias, + args=config, + layer_type="attn_proj", + ) + + self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum") + + self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in") + self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in") + self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in") + + self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out") + self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out") + self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out") + self.QAct_proj_in = QAct_FPin(config, layer_type="attn_proj_in") + + +class QMemLlamaFlashAttention2(QMemLlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + + attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + if not output_attentions: + attn_weights = None + + return attn_output, s, attn_weights, past_key_value + + +class QMemLlamaSdpaAttention(QMemLlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
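+ # is_causal is only True for prefill (q_len > 1) with no explicit mask; for
+ # single-token decoding the causal constraint is a no-op, and a provided
+ # causal_mask already encodes the causal structure.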
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + # attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + return attn_output, s, None, past_key_value + + +QMemLLAMA_ATTENTION_CLASSES = { + "eager": QMemLlamaAttention, + "flash_attention_2": QMemLlamaFlashAttention2, + "sdpa": QMemLlamaSdpaAttention, +} + + +class QMemLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config, layer_idx=layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = QMemLlamaMLP(config, layer_idx) + + self.input_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn" + ) + self.post_attention_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp" + ) + + self.QAdd_attn = QAdd(config, layer_type="add_attn") + self.QAdd_mlp = QAdd(config, layer_type="add_mlp") + + self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx") + self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re") + + self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx") + self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual, res = self.QAct_reattnout_re(hidden_states) + hidden_states, s = self.QAct_reattnout_fx(hidden_states) + + hidden_states, s = self.input_layernorm(hidden_states, s) + + # Self Attention + hidden_states, s, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + s=s, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.QAdd_attn(residual, hidden_states, res, s) + + # Fully Connected + residual, res = self.QAct_remlpout_re(hidden_states) + hidden_states, s = self.QAct_remlpout_fx(hidden_states) + + hidden_states, s = self.post_attention_layernorm(hidden_states, s) + hidden_states, s = self.mlp(hidden_states, s) + hidden_states = self.QAdd_mlp(residual, hidden_states, res, s) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QMemLlamaPreTrainedModel(LlamaPreTrainedModel): + config_class = QMemLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["QMemLlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear) or isinstance(module, QLinearTE): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class QMemLlamaModel(QMemLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: QMemLlamaConfig + """ + + def __init__(self, config: QMemLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + _update_causal_mask = LlamaModel._update_causal_mask + forward = LlamaModel.forward + + +class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = QMemLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.forward_step_id = 0 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = QMemLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + forward = LlamaForSequenceClassification.forward + + +AutoConfig.register("qmemllama", QMemLlamaConfig) +AutoModel.register(QMemLlamaConfig, QMemLlamaModel) +AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM) diff --git a/llava/model/language_model/realqmemllama.py b/llava/model/language_model/realqmemllama.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb71d015f8f7a0dcffc466b4093f67970537335 --- /dev/null +++ b/llava/model/language_model/realqmemllama.py @@ -0,0 +1,674 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import math +import os +import time +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from ..qlinear_te import QLinearTE + +try: + import transformer_engine.pytorch as te +except: + pass +from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "QMemLlamaConfig" + + +class QMemLlamaConfig(LlamaConfig): + model_type = "qmemllama" + + +class QLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.qargs = args + self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in") + self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out") + + def forward(self, hidden_states, s): + hidden_states = self.QAct_layernorm_in(hidden_states, s) + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + + hidden_states, s = self.QAct_layernorm_out(hidden_states) + return hidden_states, s + + +ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm) + + +class QMemLlamaMLP(LlamaMLP): + def __init__(self, config, layer_idx): + 
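+ # Swap the parent's nn.Linear projections for scale-aware QLinear layers and add
+ # QAct_FPin/QAct_FPout wrappers so the activation scale `s` is threaded through
+ # the gate/up/down projections and the gating multiply (QMul_act).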
super().__init__(config) + self.layer_idx = layer_idx + self.gate_proj = QLinear( + self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate" + ) + self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up") + self.down_proj = QLinear( + self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down" + ) + self.act_fn = ACT2FN[config.hidden_act] + + self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum") + self.QAct_act_gate = QAct_FPin(config, layer_type="mlp_act_gate") + self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up") + + self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in") + self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out") + + self.QMul_act = QMul(config, layer_type="mul_act") + + def forward(self, x, s): + if self.config.pretraining_tp > 1: + raise ValueError("Currently Quantization is not implemented for tensor parallel for simplicity") + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + x = self.QAct_act_sum(x, s) + x_gate, s_gate = self.QAct_act_gate(x) + x_up, s_up = self.QAct_act_up(x) + x_gate, s_gate = self.gate_proj(x_gate, s_gate) + x_gate = self.QAct_act_in(x_gate, s_gate) + x_gate = self.act_fn(x_gate) + x_gate, s_gate = self.QAct_act_out(x_gate) + + x_up, s_up = self.up_proj(x_up, s_up) + x, s = self.QMul_act(x_gate, x_up, s_gate, s_up) + down_proj, s = self.down_proj(x, s) + return down_proj, s + + +class QMemLlamaAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + self.q_proj = QLinear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_q", + ) + self.k_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_k", + ) + self.v_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_v", + ) + self.o_proj = QLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=config.attention_bias, + args=config, + layer_type="attn_proj", + ) + + self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum") + + self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in") + self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in") + self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in") + + self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out") + self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out") + self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out") + self.QAct_proj_in = 
QAct_FPin(config, layer_type="attn_proj_in") + + +class QMemLlamaFlashAttention2(QMemLlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
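+ # The views above produced [batch, heads, seq, head_dim] for RoPE and the KV cache;
+ # transpose back to the [batch, seq, heads, head_dim] layout expected by
+ # _flash_attention_forward.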
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + + attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + if not output_attentions: + attn_weights = None + + return attn_output, s, attn_weights, past_key_value + + +class QMemLlamaSdpaAttention(QMemLlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + # attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + return attn_output, s, None, past_key_value + + +QMemLLAMA_ATTENTION_CLASSES = { + "eager": QMemLlamaAttention, + "flash_attention_2": QMemLlamaFlashAttention2, + "sdpa": QMemLlamaSdpaAttention, +} + + +class QMemLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config, layer_idx=layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = QMemLlamaMLP(config, layer_idx) + self.input_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn" + ) + self.post_attention_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp" + ) + + self.QAdd_attn = QAdd(config, layer_type="add_attn") + self.QAdd_mlp = QAdd(config, layer_type="add_mlp") + + self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx") + self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re") + + self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx") + self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual, res = self.QAct_reattnout_re(hidden_states) + hidden_states, s = self.QAct_reattnout_fx(hidden_states) + + hidden_states, s = self.input_layernorm(hidden_states, s) + + # Self Attention + hidden_states, s, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + s=s, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.QAdd_attn(residual, hidden_states, res, s) + + # Fully Connected + residual, res = self.QAct_remlpout_re(hidden_states) + hidden_states, s = self.QAct_remlpout_fx(hidden_states) + + hidden_states, s = self.post_attention_layernorm(hidden_states, s) + hidden_states, s = self.mlp(hidden_states, s) + hidden_states = self.QAdd_mlp(residual, hidden_states, res, s) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QMemLlamaPreTrainedModel(LlamaPreTrainedModel): + config_class = QMemLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["QMemLlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear) or isinstance(module, QLinearTE): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class QMemLlamaModel(QMemLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: QMemLlamaConfig + """ + + def __init__(self, config: QMemLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + _update_causal_mask = LlamaModel._update_causal_mask + forward = LlamaModel.forward + + +class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = QMemLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.forward_step_id = 0 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = QMemLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + forward = LlamaForSequenceClassification.forward + + +AutoConfig.register("qmemllama", QMemLlamaConfig) +AutoModel.register(QMemLlamaConfig, QMemLlamaModel) +AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM) diff --git a/llava/model/liger/cross_entropy.py b/llava/model/liger/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..06274f332b62ac2178904439a18c3474693242a2 --- /dev/null +++ b/llava/model/liger/cross_entropy.py @@ -0,0 +1,417 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
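Usage note: because the `AutoConfig.register` / `AutoModel*.register` calls above run at import time, the quantized architecture resolves through the standard Auto factories once this module has been imported. A minimal sketch; the checkpoint path is hypothetical and must contain a `config.json` with `"model_type": "qmemllama"`:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# import the module above first so the register() calls have executed
checkpoint = "/path/to/qmemllama-checkpoint"  # hypothetical local directory

config = AutoConfig.from_pretrained(checkpoint)            # resolves to QMemLlamaConfig
model = AutoModelForCausalLM.from_pretrained(checkpoint)   # resolves to QMemLlamaForCausalLM
```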
+ +import operator +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import compare_version, element_mul_kernel, is_hip + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + +_TRUE = tl.constexpr(1) +_FALSE = tl.constexpr(0) + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + BLOCK_SIZE (int): The block size for Triton operations. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + z_loss_ptr += program_id * loss_stride + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + ori_X_y = tl.load(X_ptr + y) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. [Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i โ‰  y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / (n_non_ignore) + # chain rule + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. 
Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + loss += z_loss + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + z_loss = z_loss / n_non_ignore + loss = loss / n_non_ignore + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS == _TRUE: + tl.store(z_loss_ptr, z_loss) + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning + + +_bool_to_return_z_loss = { + True: _TRUE.value, + False: _FALSE.value, +} + + +def cross_entropy_forward( + _input, + target, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, +): + if not isinstance(return_z_loss, int): + assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}" + return_z_loss = _bool_to_return_z_loss[return_z_loss] + else: + assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. 
Got: {return_z_loss}" + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + if return_z_loss == _TRUE.value: + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + else: + z_loss_1d = loss_1d # dummy ptr when return_z_loss == False + + n_non_ignore = (target != ignore_index).sum().item() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap if softcap is not None else 0.0, + RETURN_Z_LOSS=return_z_loss, + BLOCK_SIZE=BLOCK_SIZE, + HAS_SOFTCAPPING=True if softcap is not None else False, + # TODO: 32 seems to give the best performance + # Performance is quite sensitive to num_warps + num_warps=32 if not is_hip() else 16, + ) + + loss = torch.sum(loss_1d) + if return_z_loss == _TRUE.value: + z_loss = torch.sum(z_loss_1d) + else: + z_loss = None + + return loss, z_loss, _input + + +def cross_entropy_backward(_input, grad_output): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + else: + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + return _input + + +class LigerCrossEntropyFunction(torch.autograd.Function): + """ + This class implements a custom autograd function for the Liger Cross Entropy loss. + It overrides the forward and backward methods of the torch.autograd.Function class. + """ + + @staticmethod + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + return_z_loss: bool = False, + ): + """ + The forward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object. + _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. + target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + ignore_index (int): The index to ignore in the target. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The reduction to apply to the output: "none" | "mean | "sum". 
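The "store the gradient in the logits buffer" trick referenced above can be illustrated with a small CPU-only PyTorch analogue. This is a sketch under simplifying assumptions (hard labels, mean reduction, no label smoothing / soft-capping / z-loss); the class and variable names are made up, and as in the real kernel the logits are expected to be a non-leaf activation (e.g. the `lm_head` output).

```python
import torch

class InplaceGradCrossEntropy(torch.autograd.Function):
    """Toy analogue: dL/dlogits is written into the logits storage during forward,
    so backward only has to rescale it by grad_output."""

    @staticmethod
    def forward(ctx, logits, target):
        n = logits.shape[0]
        probs = torch.softmax(logits, dim=-1)
        loss = -torch.log(probs[torch.arange(n), target]).mean()
        grad = probs                              # reuse probs as the gradient buffer
        grad[torch.arange(n), target] -= 1.0      # softmax(x) - onehot(y)
        grad /= n                                 # mean reduction
        logits.copy_(grad)                        # overwrite the input storage, like the kernel
        ctx.save_for_backward(logits.detach())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        return grad * grad_output, None

# quick check with a non-leaf logits tensor
x = torch.randn(4, 16, requires_grad=True)
logits = x * 1.0
loss = InplaceGradCrossEntropy.apply(logits, torch.randint(0, 16, (4,)))
loss.backward()
```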
+ softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). + return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False` + + Returns: + tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None. + """ + loss, z_loss, _input = cross_entropy_forward( + _input, + target, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + ) + # TODO: investigation + # If we don't detach the _input tensor, the memory will double + # Not sure why but seems that there will be a time both grad and value exist but in different location + ctx.save_for_backward(_input.detach()) + ctx.return_z_loss = return_z_loss + + return loss, z_loss + + @staticmethod + def backward(ctx, grad_output, grad_ouput2): + """ + The backward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + grad_output2 (tenosr): No use. + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. + """ + if ctx.return_z_loss: + del grad_ouput2 # z_loss is only for logging + + (_input,) = ctx.saved_tensors + _input = cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + None, + None, + None, + ) + + +def liger_fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs): + reduction = "sum" if num_items_in_batch is not None else "mean" + # loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) + loss, _ = LigerCrossEntropyFunction.apply(source, target, ignore_index, 0.0, 0.0, reduction) + if reduction == "sum": + loss = loss / num_items_in_batch + return loss + + +def LigerForCausalLMLoss( + logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs +): + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + shift_logits = shift_logits.view(-1, vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = liger_fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs) + return loss diff --git a/llava/model/liger/utils.py b/llava/model/liger/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5290b98c237f43757d0ed74d91f8f3bf6a807eeb --- /dev/null +++ b/llava/model/liger/utils.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. 
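For sanity-checking the fused path, a plain-PyTorch reference of what `LigerForCausalLMLoss` computes (shift-by-one next-token targets, flatten, cross-entropy) can be useful. The sketch below is an assumption-laden stand-in: it ignores `num_items_in_batch`, soft-capping and z-loss, and uses `F.cross_entropy` in place of the Triton kernel.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss_reference(logits, labels, vocab_size, ignore_index=-100):
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous()   # token t predicts token t + 1
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1).to(shift_logits.device),
        ignore_index=ignore_index,
        reduction="mean",
    )

# toy check
logits = torch.randn(2, 8, 32)
labels = torch.randint(0, 32, (2, 8))
print(causal_lm_loss_reference(logits, labels, vocab_size=32))
```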
+ +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + +Modifications made by Yanning Chen, 2024. +""" + +import functools +import importlib +import operator +from typing import Callable + +import torch +import triton +import triton.language as tl +from packaging.version import Version + + +def is_hip() -> bool: + return torch.version.hip is not None + + +def ensure_contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + def maybe_to_contiguous(x): + return x.contiguous() if isinstance(x, torch.Tensor) else x + + args = [maybe_to_contiguous(arg) for arg in args] + kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} + return fn(ctx, *args, **kwargs) + + return wrapper + + +def calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 65536 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds " + f"the recommended Triton blocksize = {MAX_FUSED_SIZE}." + ) + + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + return BLOCK_SIZE, num_warps + + +def compare_version(package: str, operator: Callable, target: str): + try: + pkg = importlib.import_module(package) + except ImportError: + return False + pkg_version = Version(pkg.__version__) + return operator(pkg_version, Version(target)) + + +def get_amp_custom_fwd_bwd() -> Callable: + if compare_version("torch", operator.ge, "2.4.0"): + return ( + functools.partial(torch.amp.custom_fwd, device_type="cuda"), + functools.partial(torch.amp.custom_bwd, device_type="cuda"), + ) + return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd + + +amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd() + + +torch_to_triton_dtype = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..f6feb1014fbea1a6d7904b88c5d2f2a0d84607bb --- /dev/null +++ b/llava/model/llava_arch.py @@ -0,0 +1,861 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import logging +import os +import os.path as osp +import warnings +from abc import ABC +from collections import OrderedDict, defaultdict, deque +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from einops import rearrange +from hydra.utils import instantiate +from transformers import AutoConfig, GenerationConfig, LogitsProcessor, PreTrainedModel +from transformers.modeling_utils import ContextManagers, no_init_weights + +from llava.constants import DEFAULT_SOUND_TOKEN,DEFAULT_SPEECH_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS +from llava.mm_utils import process_image, process_images, process_sounds,process_sound_masks +from llava.model.configuration_llava import LlavaConfig, ResponseFormat +from llava.model.language_model.builder import build_llm_and_tokenizer +from llava.model.multimodal_encoder.builder import build_sound_tower +from llava.model.multimodal_projector.builder import build_speech_mm_projector, build_sound_mm_projector +from llava.model.utils import get_model_config +from llava.train.sequence_parallel import get_pg_manager +from llava.utils import distributed +from llava.utils.media import extract_media +from llava.utils.tokenizer import tokenize_conversation + + +class LlavaMetaModel(ABC): + def _init_llm(self, llm_cfg, config, *args, **kwargs): + llm, tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) + return llm, tokenizer + + def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs): + # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation. 
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "speech_tower") or hasattr(self, "sound_tower") or hasattr(self, "mm_projector") or hasattr(self, "speech_mm_projector") or hasattr(self, "sound_mm_projector"): + # already initialized, skipped + return + + model_dtype = getattr(config, "model_dtype", "torch.float16") + if not hasattr(config, "model_dtype"): + warnings.warn("model_dtype not found in config, defaulting to torch.float16.") + config.model_dtype = model_dtype + + cfgs = get_model_config(config) + print(cfgs) + if len(cfgs) == 7: + llm_cfg, vision_tower_cfg, speech_tower_cfg,sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg,sound_mm_projector_cfg = cfgs + else: + raise ValueError("`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config.") + + self.llm, self.tokenizer = self._init_llm(llm_cfg, config, *args, **kwargs) + + self.sound_tower = build_sound_tower(sound_tower_cfg, config) + self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config) + + if isinstance(self.config, dict): + self.vocab_size = config.llm_cfg["vocab_size"] + NUM_EXTRA_TOKENS + else: + self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS + logging.info( + f"[XGrammar] config is not a dict, loading vocab size from tokenizer {self.tokenizer.vocab_size} + {NUM_EXTRA_TOKENS} => {self.vocab_size}" + ) + + # XGrammar tokenizer and grammar compiler + # lazy init only when specified json output during inference + self.grammar_compiler = None + + self.encoders = {} + for name in ["sound"]: + config = getattr(self.config, f"{name}_encoder") + if isinstance(config, str): + config = json.loads(config) + self.encoders[name] = instantiate(config, parent=self) + + self.post_config() + self.is_loaded = True + + assert ( + self.llm is not None or self.vision_tower is not None or self.speech_tower is not None or self.mm_projector is not None or self.speech_mm_projector is not None + ), "At least one of the components must be instantiated." 
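The per-modality encoders above are built with Hydra's `instantiate` from a dict (or JSON string) config carrying a `_target_` key. Below is a minimal sketch of that pattern; the `torch.nn.Linear` target is a stand-in, since the real configs point at the project's encoder classes and also forward `parent=self`:

```python
import json
from hydra.utils import instantiate

cfg_str = '{"_target_": "torch.nn.Linear", "in_features": 16, "out_features": 8}'
encoder = instantiate(json.loads(cfg_str))   # -> torch.nn.Linear(16, 8)
print(type(encoder))
```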
+ + @classmethod + def load_from_config(cls, model_path_or_config, *args, **kwargs): + pass + + ## FIXME we will use this function to load model in the future + @classmethod + def load_pretrained(cls, model_path_or_config, *args, **kwargs): + kwargs.pop("config", None) + + if isinstance(model_path_or_config, str): + config = AutoConfig.from_pretrained(model_path_or_config) + elif isinstance(model_path_or_config, LlavaConfig): + config = model_path_or_config + else: + raise NotImplementedError( + f"wrong type, {type(model_path_or_config)} \ + {isinstance(model_path_or_config, LlavaConfig)}" + ) + + model_dtype = getattr(config, "model_dtype", "torch.float16") + if not hasattr(config, "model_dtype"): + warnings.warn("model_dtype not found in config, defaulting to torch.float16.") + config.model_dtype = model_dtype + + cfgs = get_model_config(config) + if len(cfgs) == 7: + llm_cfg, vision_tower_cfg, speech_tower_cfg,sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg,sound_mm_projector_cfg = cfgs + else: + raise ValueError("`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config.") + + init_context = [ + no_init_weights(_enable=True), + ] + + with ContextManagers(init_context): + vlm = cls(config, *args, **kwargs) + + if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "speech_tower") or hasattr(vlm, "sound_tower") or hasattr(vlm, "mm_projector") or hasattr(vlm, "speech_mm_projector") or hasattr(vlm, "sound_mm_projector"): + if vlm.is_loaded: + return vlm + + vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) + vlm.sound_tower = build_sound_tower(sound_tower_cfg, config) + vlm.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config) + + self.post_config() + self.is_loaded = True + + # FIXME(ligeng, yunhao): llm should never be none here. + assert ( + vlm.llm is not None or vlm.vision_tower is not None or vlm.speech_tower is not None or vlm.mm_projector is not None or vlm.speech_mm_projector is not None + ), "At least one of the components must be instantiated." 
+ return vlm + + ## FIXME we will use this function to save the model in the future + def save_pretrained(self, output_dir, state_dict=None): + if state_dict is None: + # other wise fetch from deepspeed + # state_dict = accelerator.get_state_dict(is_deepspeed_enabled) + state_dict = self.state_dict() + + if getattr(self, "tokenizer", None): + self.tokenizer.save_pretrained(osp.join(output_dir, "llm")) + + if self.get_llm(): + print(f"saving llm to {osp.join(output_dir, 'llm')}") + self.llm.config._name_or_path = osp.join(output_dir, "llm") + llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}) + self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict) + self.config.llm_cfg = self.llm.config + + + if self.get_sound_tower(): + print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}") + self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower") + sound_tower_state_dict = OrderedDict( + {k.split("sound_tower.sound_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k} + ) + self.sound_tower.sound_tower.save_pretrained( + os.path.join(output_dir, "sound_tower"), + state_dict=sound_tower_state_dict, + ) + self.config.sound_tower_cfg = self.sound_tower.config + + if self.get_sound_mm_projector(): + print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}") + self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector") + sound_mm_projector_state_dict = OrderedDict( + {k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k} + ) + self.sound_mm_projector.save_pretrained( + os.path.join(output_dir, "sound_mm_projector"), + state_dict=sound_mm_projector_state_dict, + ) + self.config.sound_mm_projector_cfg = self.sound_mm_projector.config + + ## update and save top-level config + self.config._name_or_path = output_dir + self.config.architectures = [self.__class__.__name__] + self.config.save_pretrained(output_dir) + + def get_llm(self): + llm = getattr(self, "llm", None) + if type(llm) is list: + llm = llm[0] + return llm + + def get_lm_head(self): + lm_head = getattr(self.get_llm(), "lm_head", None) + return lm_head + + def get_sound_tower(self): + sound_tower = getattr(self, "sound_tower", None) + if type(sound_tower) is list: + sound_tower = sound_tower[0] + return sound_tower + + + def get_sound_mm_projector(self): + sound_mm_projector = getattr(self, "sound_mm_projector", None) + if type(sound_mm_projector) is list: + sound_mm_projector = sound_mm_projector[0] + return sound_mm_projector + + def post_config(self): + self.training = self.get_llm().training + ## configuration + if getattr(self.config, "llm_cfg", None) is None: + self.config.llm_cfg = self.llm.config + self.config.speech_tower_cfg = self.speech_tower.config + if getattr(self.config, "sound_tower_cfg", None) is None: + self.config.sound_tower_cfg = self.sound_tower.config + if getattr(self.config, "sound_mm_projector_cfg", None) is None: + self.config.sound_mm_projector_cfg = self.sound_mm_projector.config + + def freezed_module_patch(self): + """ + Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules. 
+ """ + if self.training: + if self.get_llm() and not getattr(self.config, "tune_language_model", False): + pass + + if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False): + self.get_sound_tower().eval() + if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False): + self.get_sound_mm_projector().eval() + + + def encode_sound(self, sounds, masks=None): + + sound_features = self.get_sound_tower()(sounds, masks) + sound_features = self.get_sound_mm_projector()(sound_features) + return sound_features + + ## @yunhao: is there a better way to handle function call and attributes for llm? + ## support beam search + def _temporary_reorder_cache(self, past_key_values, sorted_idx): + return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx) + + def get_input_embeddings(self): + return self.get_llm().get_input_embeddings() + + def get_output_embeddings(self): + return self.get_llm().get_output_embeddings() + + def resize_token_embeddings(self, embed_size): + self.get_llm().resize_token_embeddings(embed_size) + + +class LlavaMetaForCausalLM(ABC): + def _embed( + self, + input_ids: torch.Tensor, + media: Dict[str, List[torch.Tensor]], + media_config: Dict[str, Dict[str, Any]], + labels: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + media_meta: Dict[str, Dict[str, Any]]= None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX) + attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool) + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + for name in media: + self.encoders[name].end_tokens = None + + # Extract text and media embeddings + text_embeds = self.llm.model.embed_tokens(input_ids) + media_embeds = self.__embed_media_tokens(media, media_config, media_meta) + # This is a workaround to make sure the dummy embeddings are consumed + while media_embeds.get("dummy"): + dummy_embed = media_embeds["dummy"].popleft() + text_embeds += torch.sum(dummy_embed) * 0 + # Remove padding + batch_size = labels.shape[0] + + # Build inverse mapping from token ID to media name + media_tokens = {} + for name, token_id in self.tokenizer.media_token_ids.items(): + media_tokens[token_id] = name + + # -------------------------------- # + num_audio_tokens = torch.stack(media_meta["sound_embed_masks"], dim=0).sum(-1) + num_audio_tokens = torch.tensor([round(int(x) / 10) * 10 for x in num_audio_tokens]) + num_audios = len(media_embeds['sound']) # length of queue is the number of audios we have in total + max_audio_tokens, embed_dim = media_embeds['sound'][0].shape + + + audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to( + num_audio_tokens.device + ) < num_audio_tokens.unsqueeze(1) + + audio_embeds = [] + while media_embeds['sound']: + audio_embeds.append(media_embeds['sound'].popleft()) + audio_embeds = torch.stack(audio_embeds,dim=0) + + masked_audio_features = audio_embeds[audio_features_mask].view(-1, embed_dim) + batch_size, sequence_length = input_ids.shape + _left_padding = torch.any(attention_mask[:, 0] == 0) + _right_padding = torch.any(attention_mask[:, -1] == 0) + + left_padding = True + if batch_size > 1: + if _left_padding and not _right_padding: + left_padding = True + elif not _left_padding and _right_padding: + left_padding = False + elif not _left_padding and not _right_padding: + # both 
side is 1, so cannot tell + left_padding = self.tokenizer.padding_side == "left" + else: + # invalid attention_mask + raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + + # 1. Create a mask to know where special audio tokens are + special_audio_token_mask = input_ids == self.tokenizer.media_token_ids['sound'] #hard coded to just work with 'sound' + num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1) + + # In case the Audio model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = text_embeds.device + attention_mask = attention_mask.to(target_device) + input_ids = input_ids.to(target_device) + num_audio_tokens = num_audio_tokens.to(target_device) + batch_indices, non_audio_indices = torch.where( + (input_ids != self.tokenizer.media_token_ids['sound']) & (attention_mask == 1) + ) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged audio-text sequence. + # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens. + # `torch.cumsum` computes how each audio token shifts subsequent text token positions. + token_placeholder_num = torch.zeros_like(input_ids) + token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1 + token_placeholder_num = token_placeholder_num + 1 + new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1 + max_token_num = token_placeholder_num.sum(-1).max() + nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_audio_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_audio_indices] + batch_indices, non_audio_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_audio_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_token_num, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_token_num, dtype=attention_mask.dtype, device=text_embeds.device + ) + final_input_ids = torch.full( + (batch_size, max_token_num), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=text_embeds.device + ) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "