SreyanG-NVIDIA committed (verified)
Commit 174ae06 · Parent(s): 004a685

Upload 225 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +20 -0
  2. README.md +139 -14
  3. llava/__init__.py +8 -0
  4. llava/cli/infer_audio.py +88 -0
  5. llava/constants.py +55 -0
  6. llava/conversation.py +197 -0
  7. llava/data/__init__.py +9 -0
  8. llava/data/base.py +95 -0
  9. llava/data/builder.py +193 -0
  10. llava/data/collate.py +166 -0
  11. llava/data/dataset.py +1635 -0
  12. llava/data/datasets_mixture.py +80 -0
  13. llava/data/registry/datasets/audio_test.yaml +97 -0
  14. llava/data/registry/datasets/default.yaml +5 -0
  15. llava/data/registry/mixtures.yaml +78 -0
  16. llava/entry.py +60 -0
  17. llava/eval/__init__.py +15 -0
  18. llava/eval/eval_audio_bench.py +117 -0
  19. llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc +0 -0
  20. llava/eval/mmmu_utils/eval_utils.py +61 -0
  21. llava/eval/registry_audio.yaml +93 -0
  22. llava/media.py +47 -0
  23. llava/mm_utils.py +641 -0
  24. llava/model/FloatPointQuantizeTorch.py +85 -0
  25. llava/model/FloatPointQuantizeTriton.py +199 -0
  26. llava/model/__init__.py +35 -0
  27. llava/model/apply_delta.py +77 -0
  28. llava/model/builder.py +161 -0
  29. llava/model/coat/activation/__init__.py +6 -0
  30. llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py +101 -0
  31. llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py +181 -0
  32. llava/model/coat/activation/fake_quantization/quantize_function.py +239 -0
  33. llava/model/coat/activation/fake_quantization/utils.py +115 -0
  34. llava/model/coat/activation/models/_fp8_quantization_config.py +67 -0
  35. llava/model/coat/activation/models/_fp8_weightcache.py +48 -0
  36. llava/model/coat/activation/models/_fp8manager.py +31 -0
  37. llava/model/coat/activation/models/coat_llama.py +1479 -0
  38. llava/model/coat/activation/models/coat_llama_convert_from_hf.py +71 -0
  39. llava/model/coat/activation/models/coat_olmo.py +1942 -0
  40. llava/model/coat/activation/real_quantization/__init__.py +31 -0
  41. llava/model/coat/activation/real_quantization/_dequantize.py +162 -0
  42. llava/model/coat/activation/real_quantization/_division.py +212 -0
  43. llava/model/coat/activation/real_quantization/_division_transpose.py +215 -0
  44. llava/model/coat/activation/real_quantization/_memory_io.py +180 -0
  45. llava/model/coat/activation/real_quantization/_quantize.py +176 -0
  46. llava/model/coat/activation/real_quantization/_quantize_pertensor.py +152 -0
  47. llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py +155 -0
  48. llava/model/coat/activation/real_quantization/_transpose.py +121 -0
  49. llava/model/coat/activation/real_quantization/add_bwd.py +205 -0
  50. llava/model/coat/activation/real_quantization/add_fwd.py +219 -0
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
37
+ llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o filter=lfs diff=lfs merge=lfs -text
38
+ llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o filter=lfs diff=lfs merge=lfs -text
39
+ llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o filter=lfs diff=lfs merge=lfs -text
40
+ static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
41
+ static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
42
+ static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
43
+ static/audio/audio2.wav filter=lfs diff=lfs merge=lfs -text
44
+ static/chat/audio1.mp3 filter=lfs diff=lfs merge=lfs -text
45
+ static/chat/audio2.mp3 filter=lfs diff=lfs merge=lfs -text
46
+ static/emergent/audio1.wav filter=lfs diff=lfs merge=lfs -text
47
+ static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
48
+ static/speech/339a1acd-afcb-466b-a7b1-8661e59b1e56.wav filter=lfs diff=lfs merge=lfs -text
49
+ static/speech/audio3.wav filter=lfs diff=lfs merge=lfs -text
50
+ static/speech/bcc6057d-0dda-435d-b956-a96ab27bc9e4.wav filter=lfs diff=lfs merge=lfs -text
51
+ static/speech/be84d293-5e9c-4158-9a1e-b4dd1acb7d70.wav filter=lfs diff=lfs merge=lfs -text
52
+ static/speech/fec3402e-7883-45c0-90d4-38647f615dc3.wav filter=lfs diff=lfs merge=lfs -text
53
+ static/think/audio1.wav filter=lfs diff=lfs merge=lfs -text
54
+ static/think/audio2.wav filter=lfs diff=lfs merge=lfs -text
55
+ static/voice/voice_2.mp3 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,139 @@
1
- ---
2
- title: Audio Flamingo 3
3
- emoji:
4
- colorFrom: pink
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.36.2
8
- app_file: app.py
9
- pinned: false
10
- license: other
11
- short_description: Online demo for Audio Flamingo 3
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+
2
+ <div align="center" style="display: flex; justify-content: center; align-items: center; text-align: center;">
3
+ <a href="https://github.com/NVIDIA/audio-flamingo" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
4
+ <img src="static/logo-no-bg.png" alt="Audio Flamingo 3 🔥🚀🔥" width="120">
5
+ </a>
6
+ </div>
7
+ <div align="center" style="display: flex; justify-content: center; align-items: center; text-align: center;">
8
+ <h2>
9
+ Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models
10
+ </h2>
11
+ </div>
12
+
13
+ <div align="center" style="display: flex; justify-content: center; margin-top: 10px;">
14
+ <a href=""><img src="https://img.shields.io/badge/arXiv-2503.03983-AD1C18" style="margin-right: 5px;"></a>
15
+ <a href="https://research.nvidia.com/labs/adlr/AF3/"><img src="https://img.shields.io/badge/Demo page-228B22" style="margin-right: 5px;"></a>
16
+ <a href="https://github.com/NVIDIA/audio-flamingo"><img src='https://img.shields.io/badge/Github-Audio Flamingo 3-9C276A' style="margin-right: 5px;"></a>
17
+ <a href="https://github.com/NVIDIA/audio-flamingo/stargazers"><img src="https://img.shields.io/github/stars/NVIDIA/audio-flamingo.svg?style=social"></a>
18
+ </div>
19
+
20
+ <div align="center" style="display: flex; justify-content: center; margin-top: 10px; flex-wrap: wrap; gap: 5px;">
21
+ <a href="https://huggingface.co/nvidia/audio-flamingo-3">
22
+ <img src="https://img.shields.io/badge/🤗-Checkpoints-ED5A22.svg">
23
+ </a>
24
+ <a href="https://huggingface.co/nvidia/audio-flamingo-3-chat">
25
+ <img src="https://img.shields.io/badge/🤗-Checkpoints (Chat)-ED5A22.svg">
26
+ </a>
27
+ <a href="https://huggingface.co/datasets/nvidia/AudioSkills">
28
+ <img src="https://img.shields.io/badge/🤗-Dataset: AudioSkills--XL-ED5A22.svg">
29
+ </a>
30
+ <a href="https://huggingface.co/datasets/nvidia/LongAudio">
31
+ <img src="https://img.shields.io/badge/🤗-Dataset: LongAudio--XL-ED5A22.svg">
32
+ </a>
33
+ <a href="https://huggingface.co/datasets/nvidia/AF-Chat">
34
+ <img src="https://img.shields.io/badge/🤗-Dataset: AF--Chat-ED5A22.svg">
35
+ </a>
36
+ <a href="https://huggingface.co/datasets/nvidia/AF-Think">
37
+ <img src="https://img.shields.io/badge/🤗-Dataset: AF--Think-ED5A22.svg">
38
+ </a>
39
+ </div>
40
+
41
+ <div align="center" style="display: flex; justify-content: center; margin-top: 10px;">
42
+ <a href="https://huggingface.co/spaces/nvidia/audio_flamingo_3"><img src="https://img.shields.io/badge/🤗-Gradio Demo (7B)-5F9EA0.svg" style="margin-right: 5px;"></a>
43
+ </div>
44
+
45
+ ## Overview
46
+
47
+ This repo contains the PyTorch implementation of [Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models](). Audio Flamingo 3 (AF3) is a fully open, state-of-the-art Large Audio-Language Model (LALM) that advances reasoning and understanding across speech, sounds, and music. AF3 builds on previous work with innovations in:
48
+
49
+ - Unified audio representation learning (speech, sound, music)
50
+ - Flexible, on-demand chain-of-thought reasoning (Thinking in Audio)
51
+ - Long-context audio comprehension (including speech, for inputs up to 10 minutes)
52
+ - Multi-turn, multi-audio conversational dialogue (AF3-Chat)
53
+ - Voice-to-voice interaction (AF3-Chat)
54
+
55
+ Extensive evaluations confirm AF3’s effectiveness: it sets new state-of-the-art results on over 20 public audio understanding and reasoning benchmarks.
56
+
57
+
58
+ ## Main Results
59
+
60
+ Audio Flamingo 3 outperforms prior SOTA models, including GAMA, Audio Flamingo, Audio Flamingo 2, Qwen-Audio, Qwen2-Audio, Qwen2.5-Omni, LTU, LTU-AS, SALMONN, AudioGPT, Gemini Flash v2, and Gemini Pro v1.5, on a number of understanding and reasoning benchmarks.
61
+
62
+ <div align="center">
63
+ <img class="img-full" src="static/af3_radial-1.png" width="300">
64
+ </div>
65
+
66
+ <div align="center">
67
+ <img class="img-full" src="static/af3_sota.png" width="400">
68
+ </div>
69
+
70
+ ## Audio Flamingo 3 Architecture
71
+
72
+ Audio Flamingo 3 uses the AF-Whisper unified audio encoder, an MLP-based audio adaptor, a decoder-only LLM backbone (Qwen2.5-7B), and a streaming TTS module (AF3-Chat).
73
+ Audio Flamingo 3 accepts audio inputs of up to 10 minutes.
74
+
75
+ <div align="center">
76
+ <img class="img-full" src="static/af3_main_diagram-1.png" width="800">
77
+ </div>
78
+
79
+ ## Installation
80
+
81
+ ```bash
82
+ ./environment_setup.sh af3
83
+ ```
84
+
85
+ ## Code Structure
86
+
87
+ - The folder ```audio_flamingo_3/``` contains the main training and inference code of Audio Flamingo 3.
88
+ - The folder ```audio_flamingo_3/scripts``` contains the inference scripts of Audio Flamingo 3 for use with our pretrained checkpoints on Hugging Face.
89
+
90
+ Each folder is self-contained, and we expect no cross-dependencies between these folders. This repo does not contain the code for the streaming-TTS pipeline, which will be released in the near future.
91
+
92
+ ## Single Line Inference
93
+
94
+ To run inference with the stage 3 model directly, use the command below:
95
+ ```bash
96
+ python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav
97
+ ```
98
+
99
+ To run inference with the stage 3.5 model, use the command below:
100
+ ```bash
101
+ python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --model-path /path/to/checkpoint/af3-7b/stage35 --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav --peft-mode
102
+ ```
103
+
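+ The CLI above wraps a small Python API. Below is a minimal programmatic sketch based on ```llava/cli/infer_audio.py```; the checkpoint path is a placeholder for your local download of the checkpoint.
+ ```python
+ # Minimal sketch mirroring llava/cli/infer_audio.py; the model path is a placeholder.
+ import llava
+ from llava import conversation as clib
+ from llava.media import Sound
+
+ model = llava.load("/path/to/checkpoint/af3-7b")
+ clib.default_conversation = clib.conv_templates["auto"].copy()
+
+ prompt = [Sound("static/audio/audio2.wav"), "Please describe the audio in detail."]
+ response = model.generate_content(prompt)
+ print(response)
+ ```
+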
104
+ ## References
105
+
106
+ The main training and inference code within each folder is modified from [NVILA](https://github.com/NVlabs/VILA/tree/main) ([Apache 2.0 license](incl_licenses/License_1.md)).
107
+
108
+ ## License
109
+
110
+ - The code in this repo is under [MIT license](incl_licenses/MIT_license.md).
111
+ - The checkpoints are for non-commercial use only under the [NVIDIA OneWay Noncommercial License](incl_licenses/NVIDIA_OneWay_Noncommercial_License.docx). They are also subject to the [Qwen Research license](https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/LICENSE), the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the original licenses accompanying each training dataset.
112
+ - Notice: Audio Flamingo 3 is built with Qwen-2.5. Qwen is licensed under the Qwen RESEARCH LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
113
+
114
+
115
+ ## Citation
116
+
117
+ - Audio Flamingo 2
118
+ ```
119
+ @article{ghosh2025audio,
120
+ title={Audio Flamingo 2: An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities},
121
+ author={Ghosh, Sreyan and Kong, Zhifeng and Kumar, Sonal and Sakshi, S and Kim, Jaehyeon and Ping, Wei and Valle, Rafael and Manocha, Dinesh and Catanzaro, Bryan},
122
+ journal={arXiv preprint arXiv:2503.03983},
123
+ year={2025}
124
+ }
125
+ ```
126
+
127
+ - Audio Flamingo
128
+ ```
129
+ @inproceedings{kong2024audio,
130
+ title={Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities},
131
+ author={Kong, Zhifeng and Goel, Arushi and Badlani, Rohan and Ping, Wei and Valle, Rafael and Catanzaro, Bryan},
132
+ booktitle={International Conference on Machine Learning},
133
+ pages={25125--25148},
134
+ year={2024},
135
+ organization={PMLR}
136
+ }
137
+ ```
138
+
139
+
llava/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ from .entry import *
8
+ from .media import *
llava/cli/infer_audio.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import argparse
8
+ import importlib.util
9
+ import json
10
+ import os
11
+
12
+ from pydantic import BaseModel
13
+ from termcolor import colored
14
+
15
+ import llava
16
+ from llava import conversation as clib
17
+ from llava.media import Image, Video, Sound
18
+ from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat
19
+ from peft import PeftModel
20
+ import torch
21
+
22
+ def get_schema_from_python_path(path: str) -> str:
23
+ schema_path = os.path.abspath(path)
24
+ spec = importlib.util.spec_from_file_location("schema_module", schema_path)
25
+ schema_module = importlib.util.module_from_spec(spec)
26
+ spec.loader.exec_module(schema_module)
27
+
28
+ # Get the Main class from the loaded module
29
+ Main = schema_module.Main
30
+ assert issubclass(
31
+ Main, BaseModel
32
+ ), f"The provided python file {path} does not contain a class Main that describes a JSON schema"
33
+ return Main.schema_json()
34
+
35
+
36
+ def main() -> None:
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--model-base", "-mb", type=str, required=True)
39
+ parser.add_argument("--model-path", "-mp", type=str, required=True)
40
+ parser.add_argument("--conv-mode", "-c", type=str, default="auto")
41
+ parser.add_argument("--text", type=str)
42
+ parser.add_argument("--media", type=str, nargs="+")
43
+ parser.add_argument("--json-mode", action="store_true")
44
+ parser.add_argument("--peft-mode", action="store_true")
45
+ parser.add_argument("--json-schema", type=str, default=None)
46
+ args = parser.parse_args()
47
+
48
+ # Convert json mode to response format
49
+ if not args.json_mode:
50
+ response_format = None
51
+ elif args.json_schema is None:
52
+ response_format = ResponseFormat(type="json_object")
53
+ else:
54
+ schema_str = get_schema_from_python_path(args.json_schema)
55
+ print(schema_str)
56
+ response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str))
57
+
58
+ # Load model
59
+ model = llava.load(args.model_base)
60
+ if args.peft_mode:
61
+ model = PeftModel.from_pretrained(
62
+ model,
63
+ args.model_path,
64
+ device_map="auto",
65
+ torch_dtype=torch.float16,
66
+ )
67
+ # Set conversation mode
68
+ clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
69
+
70
+ # Prepare multi-modal prompt
71
+ prompt = []
72
+ if args.media is not None:
73
+ for media in args.media or []:
74
+ if any(media.endswith(ext) for ext in [".wav", ".mp3", ".flac"]):
75
+ media = Sound(media)
76
+ else:
77
+ raise ValueError(f"Unsupported media type: {media}")
78
+ prompt.append(media)
79
+ if args.text is not None:
80
+ prompt.append(args.text)
81
+
82
+ # Generate response
83
+ response = model.generate_content(prompt, response_format=response_format)
84
+ print(colored(response, "cyan", attrs=["bold"]))
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()
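Usage note (editor's sketch): `--json-schema` expects a Python file that defines a pydantic `Main` class; `get_schema_from_python_path` loads it and passes `Main.schema_json()` to the model as a response format, and it only takes effect together with `--json-mode`. The file name and fields below are illustrative only.
```python
# schema.py -- hypothetical example for --json-schema; only the class name `Main`
# and the BaseModel requirement come from get_schema_from_python_path above.
from typing import List
from pydantic import BaseModel

class Main(BaseModel):
    caption: str
    sound_events: List[str]
```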
llava/constants.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
24
+
25
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
26
+ WORKER_HEART_BEAT_INTERVAL = 15
27
+
28
+ LOGDIR = "."
29
+
30
+ # Model Constants
31
+ IGNORE_INDEX = -100
32
+ DEFAULT_SOUND_TOKEN = "<sound>"
33
+ DEFAULT_SPEECH_TOKEN = "<speech>"
34
+ SENTINEL_TOKEN = "<vila/sentinel>"
35
+
36
+ MEDIA_TOKENS = {
37
+ "speech": "<speech>",
38
+ "sound": "<sound>",
39
+ }
40
+
41
+
42
+ """
43
+ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
44
+ 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
45
+ 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
46
+ 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
47
+ 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
48
+ 151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
49
+ 151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
50
+ 151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
51
+ 151651: AddedToken("<sound>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
52
+ 151652: AddedToken("<speech>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
53
+
54
+ """
55
+ NUM_EXTRA_TOKENS = 10
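As a small illustration (not part of the file), the `MEDIA_TOKENS` mapping is what downstream code uses to splice audio placeholders into text prompts:
```python
# Illustrative only: build a text prompt containing a sound placeholder token.
from llava.constants import MEDIA_TOKENS

prompt_text = f"{MEDIA_TOKENS['sound']}\nPlease describe the audio in detail."
# -> "<sound>\nPlease describe the audio in detail."
```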
llava/conversation.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
23
+
24
+ import dataclasses
25
+ from enum import Enum, auto
26
+ from typing import List
27
+
28
+ from llava.utils.logging import logger
29
+
30
+
31
+ class SeparatorStyle(Enum):
32
+ """Different separator style."""
33
+
34
+ AUTO = auto()
35
+ TWO = auto()
36
+ MPT = auto()
37
+ PLAIN = auto()
38
+ LLAMA_3 = auto()
39
+
40
+
41
+ @dataclasses.dataclass
42
+ class Conversation:
43
+ """A class that keeps all conversation history."""
44
+
45
+ system: str
46
+ roles: List[str]
47
+ messages: List[List[str]]
48
+ sep_style: SeparatorStyle = SeparatorStyle.AUTO
49
+ sep: str = "###"
50
+ sep2: str = None
51
+ version: str = "Unknown"
52
+
53
+ def get_prompt(self):
54
+ messages = self.messages
55
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
56
+ messages = self.messages.copy()
57
+ init_role, init_msg = messages[0].copy()
58
+ init_msg = init_msg[0].replace("<image>", "").strip()
59
+ messages[0] = (init_role, "<image>\n" + init_msg)
60
+
61
+ if self.sep_style == SeparatorStyle.TWO:
62
+ seps = [self.sep, self.sep2]
63
+ ret = self.system + seps[0]
64
+ for i, (role, message) in enumerate(messages):
65
+ if message:
66
+ if type(message) is tuple:
67
+ message, _, _ = message
68
+ ret += role + ": " + message + seps[i % 2]
69
+ else:
70
+ ret += role + ":"
71
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
72
+ ret = self.system + self.sep
73
+ for rid, (role, message) in enumerate(messages):
74
+ if message:
75
+ if type(message) is tuple:
76
+ message = message[0]
77
+ sep = self.sep if rid < len(messages) - 1 else self.sep2
78
+ ret += role + message + sep
79
+ else:
80
+ ret += role
81
+ elif self.sep_style == SeparatorStyle.MPT:
82
+ ret = self.system + self.sep
83
+ for role, message in messages:
84
+ if message:
85
+ if type(message) is tuple:
86
+ message, _, _ = message
87
+ ret += role + message + self.sep
88
+ else:
89
+ ret += role
90
+ elif self.sep_style == SeparatorStyle.PLAIN:
91
+ seps = [self.sep, self.sep2]
92
+ ret = self.system
93
+ for i, (role, message) in enumerate(messages):
94
+ if message:
95
+ if type(message) is tuple:
96
+ message, _, _ = message
97
+ ret += message + seps[i % 2]
98
+ else:
99
+ ret += ""
100
+ else:
101
+ raise ValueError(f"Invalid style: {self.sep_style}")
102
+
103
+ return ret
104
+
105
+ def append_message(self, role, message):
106
+ self.messages.append([role, message])
107
+
108
+ def copy(self):
109
+ return Conversation(
110
+ system=self.system,
111
+ roles=self.roles,
112
+ messages=[[x, y] for x, y in self.messages],
113
+ sep_style=self.sep_style,
114
+ sep=self.sep,
115
+ sep2=self.sep2,
116
+ version=self.version,
117
+ )
118
+
119
+
120
+ conv_auto = Conversation(
121
+ system="",
122
+ roles=("", ""),
123
+ messages=(),
124
+ sep_style=SeparatorStyle.AUTO,
125
+ sep="\n",
126
+ )
127
+
128
+ conv_vicuna_v1 = Conversation(
129
+ system="A chat between a curious user and an artificial intelligence assistant. "
130
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
131
+ roles=("USER", "ASSISTANT"),
132
+ version="v1",
133
+ messages=(),
134
+ sep_style=SeparatorStyle.TWO,
135
+ sep=" ",
136
+ sep2="</s>",
137
+ )
138
+
139
+ conv_llava_plain = Conversation(
140
+ system="",
141
+ roles=("", ""),
142
+ messages=(),
143
+ sep_style=SeparatorStyle.PLAIN,
144
+ sep="\n",
145
+ )
146
+
147
+ hermes_2 = Conversation(
148
+ system="<|im_start|>system\nAnswer the questions.",
149
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
150
+ sep_style=SeparatorStyle.MPT,
151
+ sep="<|im_end|>",
152
+ messages=(),
153
+ version="hermes-2",
154
+ )
155
+
156
+ # Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
157
+ llama_3_chat = Conversation(
158
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
159
+ "You are able to understand the visual content that the user provides, "
160
+ "and assist the user with a variety of tasks using natural language.",
161
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
162
+ version="llama_v3",
163
+ messages=(),
164
+ sep_style=SeparatorStyle.LLAMA_3,
165
+ sep="<|eot_id|>",
166
+ sep2="<|end_of_text|>",
167
+ )
168
+
169
+
170
+ default_conversation = conv_auto
171
+ conv_templates = {
172
+ "auto": conv_auto,
173
+ "hermes-2": hermes_2,
174
+ "llama_3": llama_3_chat,
175
+ "v1": conv_vicuna_v1,
176
+ "vicuna_v1": conv_vicuna_v1,
177
+ "plain": conv_llava_plain,
178
+ }
179
+
180
+
181
+ CONVERSATION_MODE_MAPPING = {
182
+ "vila1.5-3b": "vicuna_v1",
183
+ "vila1.5-8b": "llama_3",
184
+ "vila1.5-13b": "vicuna_v1",
185
+ "vila1.5-40b": "hermes-2",
186
+ "llama-3": "llama_3",
187
+ "llama3": "llama_3",
188
+ }
189
+
190
+
191
+ def auto_set_conversation_mode(model_name_or_path: str) -> None:
192
+ global default_conversation
193
+ for k, v in CONVERSATION_MODE_MAPPING.items():
194
+ if k in model_name_or_path.lower():
195
+ logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
196
+ default_conversation = conv_templates[v]
197
+ return
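For reference, a short sketch (not part of the file) of how these templates produce a prompt string:
```python
# Illustrative usage of the conversation templates defined above.
from llava import conversation as clib

conv = clib.conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], "What sounds are present in this clip?")
conv.append_message(conv.roles[1], None)  # empty slot: get_prompt() ends with "ASSISTANT:"
print(conv.get_prompt())
```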
llava/data/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ from .builder import *
8
+ from .dataset import *
9
+ from .datasets_mixture import *
llava/data/base.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import random
8
+ from typing import Any, Dict, List
9
+
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+ from transformers import PreTrainedTokenizer
13
+
14
+ from llava.mm_utils import dynamic_process_images_and_prompt, dynamic_s2_process_images_and_prompt, process_images
15
+ from llava.train.args import DataArguments
16
+ from llava.utils.logging import logger
17
+ from llava.utils.media import extract_media
18
+ from llava.utils.tokenizer import preprocess_conversation
19
+
20
+ __all__ = ["BaseDataset"]
21
+
22
+ def _process_speech(speech: List[Any], data_args: DataArguments) -> torch.Tensor:
23
+ return torch.tensor(speech)
24
+
25
+ def _process_sound(sound: List[Any], data_args: DataArguments) -> torch.Tensor:
26
+ return torch.tensor(sound)
27
+
28
+ def _process_sound_masks(sound_masks: List[Any], data_args: DataArguments) -> torch.Tensor:
29
+ return torch.tensor(sound_masks)
30
+
31
+
32
+ class BaseDataset(Dataset):
33
+ def __init__(
34
+ self,
35
+ tokenizer: PreTrainedTokenizer,
36
+ data_args: DataArguments,
37
+ no_system_prompt: bool = False,
38
+ **kwargs: Any,
39
+ ) -> None:
40
+ super().__init__()
41
+ self.tokenizer = tokenizer
42
+ self.data_args = data_args
43
+ self.no_system_prompt = no_system_prompt
44
+ self.instances = []
45
+ self.enable_dynamic_res = False
46
+ self.enable_dynamic_res_s2 = False
47
+ # global_batch_size: int,
48
+ self.global_batch_size = kwargs.get("global_batch_size", 1)
49
+
50
+ # by default, dataset cls will resample on failure
51
+ self.resample_on_failure = kwargs.get("resample_on_failure", True)
52
+
55
+
56
+ def process(self, instance: Dict[str, Any]) -> List[Dict[str, Any]]:
57
+ raise NotImplementedError
58
+
59
+ def __getitem__(self, index: int) -> Dict[str, Any]:
60
+ instance = self.instances[index]
61
+
62
+ try:
63
+ # Process instance to conversation
64
+ conversation = self.process(instance)
65
+
66
+ # Extract media from conversation
67
+ media, media_meta = extract_media(conversation, self.data_args)
68
+
69
+ if "speech" in media:
70
+ processed_speech = _process_speech(media["speech"], self.data_args)
71
+ if "sound" in media:
72
+ processed_sound = _process_sound(media["sound"], self.data_args)
73
+ processed_sound_feature_masks = _process_sound_masks(media_meta["sound_feature_masks"], self.data_args)
74
+ processed_sound_embed_masks = _process_sound_masks(media_meta["sound_embed_masks"], self.data_args)
75
+ # Prepare "input_ids" and "labels" for training
76
+ data = preprocess_conversation(conversation, self.tokenizer, no_system_prompt=self.no_system_prompt)
77
+
78
+ if "speech" in media:
79
+ data["speech"] = processed_speech
80
+ if "sound" in media:
81
+ data["sound"] = processed_sound
82
+ data["sound_feature_masks"] = processed_sound_feature_masks
83
+ data["sound_embed_masks"] = processed_sound_embed_masks
84
+
85
+ except Exception as e:
86
+ if not self.resample_on_failure:
87
+ raise e
88
+ else:
89
+ logger.exception(f"Error processing instance '{instance}': '{e}'. Resampling.")
90
+ return self.__getitem__(random.randint(0, len(self.instances) - 1))
91
+
92
+ return data
93
+
94
+ def __len__(self) -> int:
95
+ return len(self.instances)
llava/data/builder.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import os
8
+ import os.path as osp
9
+ from itertools import chain
10
+ from typing import Any, List, Optional
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ from hydra.utils import instantiate
15
+ from torch.utils.data import ConcatDataset, Dataset
16
+ from transformers import PreTrainedTokenizer
17
+
18
+ from llava.data.datasets_mixture import DATASETS_LEGACY
19
+ from llava.train.args import DataArguments, TrainingArguments
20
+ from llava.utils import io
21
+ from llava.utils.logging import logger
22
+ import time
23
+ import numpy as np
24
+ __all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"]
25
+
26
+
27
+ def load_dataset_yaml(name):
28
+ fname = f"{name}.yaml" if not name.endswith(".yaml") else name
29
+
30
+ # yaml under llava/data/registry/datasets
31
+ repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname)
32
+ if osp.exists(repo_path):
33
+ return repo_path
34
+
35
+ # # yaml under <fs yaml path>
36
+ abs_path = osp.expanduser(fname)
37
+ if osp.exists(abs_path):
38
+ return abs_path
39
+
40
+ raise FileNotFoundError(f"Dataset '{name}' was not found at {repo_path} or {abs_path}.")
41
+
42
+
43
+ def register_datasets(name: Optional[str] = None):
44
+ if name is None:
45
+ name = os.environ.get("VILA_DATASETS", "default")
46
+ logger.info(f"Registering datasets from environment: '{name}'.")
47
+ # return io.load(osp.join(osp.dirname(__file__), "registry", "datasets", f"{name}.yaml"))
48
+ dataset_meta = {}
49
+ for _name in name.split(","):
50
+ yamlpath = load_dataset_yaml(_name)
51
+ logger.info(f"Registering datasets from: '{yamlpath}'.")
52
+ meta = io.load(yamlpath)
53
+ dataset_meta.update(meta)
54
+ return dataset_meta
55
+
56
+
57
+ def register_mixtures():
58
+ return io.load(os.path.join(os.path.dirname(__file__), "registry", "mixtures.yaml"))
59
+
60
+
61
+ DATASETS = register_datasets()
62
+ MIXTURES = register_mixtures()
63
+
64
+
65
+ def parse_mixture(mixture: str) -> List[str]:
66
+ names = mixture.split("+") if "+" in mixture else [mixture]
67
+ while any(name in MIXTURES for name in names):
68
+ names = list(chain(*[MIXTURES.get(name, [name]) for name in names]))
69
+ return sorted(names)
70
+
71
+
72
+ class SubsetDataset(Dataset):
73
+ def __init__(self, dataset: Dataset, limit: int) -> None:
74
+ super().__init__()
75
+ self.dataset = dataset
76
+ self.limit = limit
77
+
78
+ def __len__(self) -> int:
79
+ return int(len(self.dataset) * self.limit)
80
+
81
+ def __getitem__(self, index: int) -> Any:
82
+ return self.dataset[index % len(self.dataset)]
83
+
84
+ class RepeatedDataset(Dataset):
85
+ def __init__(self, dataset: Dataset, times: int) -> None:
86
+ super().__init__()
87
+ self.dataset = dataset
88
+ self.times = times
89
+
90
+ def __len__(self) -> int:
91
+ return len(self.dataset) * self.times
92
+
93
+ def __getitem__(self, index: int) -> Any:
94
+ return self.dataset[index % len(self.dataset)]
95
+
96
+
97
+ def get_world_size():
98
+ if torch.distributed.is_initialized():
99
+ return torch.distributed.get_world_size()
100
+ else:
101
+ return 1
102
+
103
+
104
+ def build_dataset(
105
+ mixture: str,
106
+ data_args: DataArguments,
107
+ training_args: TrainingArguments,
108
+ tokenizer: PreTrainedTokenizer,
109
+ ) -> Dataset:
110
+ logger.warning(f"Training VILA with mixture '{mixture}'.")
111
+ datasets = []
112
+ dataset_rng = np.random.default_rng(1234)
113
+ for name in parse_mixture(mixture):
114
+
115
+ if "*" in name:
116
+ name, times = name.split("*")
117
+ times = int(times)
118
+ else:
119
+ times = 1
120
+ limit_dataset = False
121
+ if "#" in name:
122
+ # we limit the max length of this dataset
123
+ name, max_length_percent = name.split("#")
124
+ limit_dataset = True
125
+ if DATASETS is not None and name in DATASETS:
126
+ if name in DATASETS_LEGACY:
127
+ logger.warning(f"Dataset '{name}' exists in both new and legacy registries. Using the new one.")
128
+ dataset = instantiate(DATASETS[name], _partial_=True)(
129
+ tokenizer=tokenizer,
130
+ data_args=data_args,
131
+ global_batch_size=(
132
+ training_args.per_device_train_batch_size
133
+ # * torch.distributed.get_world_size()
134
+ * get_world_size()
135
+ * training_args.gradient_accumulation_steps
136
+ ),
137
+ )
138
+ elif name in DATASETS_LEGACY:
139
+ logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.")
140
+ dataset = build_dataset_legacy(
141
+ name,
142
+ data_args=data_args,
143
+ training_args=training_args,
144
+ tokenizer=tokenizer,
145
+ )
146
+ else:
147
+ raise ValueError(f"Dataset '{name}' is not found in the registries.")
148
+
149
+
150
+ if limit_dataset:
151
+ # keep only max_length_percent% of this dataset
+ dataset = SubsetDataset(dataset, int(max_length_percent) / 100.0)
154
+
155
+ if times > 1:
156
+ dataset = RepeatedDataset(dataset, times)
157
+ datasets.append(dataset)
158
+ return ConcatDataset(datasets)
159
+
160
+
161
+ def build_dataset_legacy(
162
+ name: str,
163
+ data_args: DataArguments,
164
+ training_args: TrainingArguments,
165
+ tokenizer: PreTrainedTokenizer,
166
+ ) -> Dataset:
167
+ from llava.data.dataset import (
168
+ LazySupervisedDataset,
169
+ LazyWDSDataset,
170
+ )
171
+
172
+ dataset = DATASETS_LEGACY[name]
173
+ dataset_type = dataset.dataset_type
174
+ if dataset_type == "torch":
175
+ dataset_cls = LazySupervisedDataset
176
+ elif dataset_type == "wds":
177
+ dataset_cls = LazyWDSDataset
178
+ else:
179
+ raise NotImplementedError(f"{dataset_type} is not supported.")
180
+
181
+ data_args.meta_path = getattr(dataset, "meta_path", None)
182
+ data_args.caption_choice = getattr(dataset, "caption_choice", None)
183
+ data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None)
184
+ data_args.start_idx = getattr(dataset, "start_idx", None)
185
+ data_args.end_idx = getattr(dataset, "end_idx", None)
186
+
187
+ return dataset_cls(
188
+ tokenizer=tokenizer,
189
+ data_path=dataset.data_path,
190
+ image_folder=getattr(dataset, "image_path"),
191
+ data_args=data_args,
192
+ training_args=training_args,
193
+ )
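The mixture string parsed by `parse_mixture` and `build_dataset` supports three operators: `+` concatenates entries, a `*N` suffix repeats a dataset N times (via `RepeatedDataset`), and a `#P` suffix keeps roughly P percent of it (via `SubsetDataset`). A hedged sketch, assuming a mixture named `my_mix` is registered in `registry/mixtures.yaml`:
```python
# Illustrative only: "my_mix" and the dataset names are hypothetical registry entries.
from llava.data.builder import parse_mixture

# Assuming mixtures.yaml contains:  my_mix: [audio_caps, clotho]
print(parse_mixture("my_mix+music_qa*2"))
# -> ['audio_caps', 'clotho', 'music_qa*2']  (sorted; the *2 suffix is handled later in build_dataset)
```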
llava/data/collate.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, Sequence
9
+
10
+ import torch
11
+ from transformers import PreTrainedTokenizer
12
+
13
+ from llava.constants import IGNORE_INDEX
14
+ from llava.utils.logging import logger
15
+
16
+ __all__ = ["DataCollator"]
17
+
18
+
19
+ @dataclass
20
+ class DataCollator:
21
+ tokenizer: PreTrainedTokenizer
22
+
23
+ def __init__(self, tokenizer: PreTrainedTokenizer):
24
+ super().__init__()
25
+ self.tokenizer = tokenizer
26
+
27
+ def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
28
+ # Gather everything from the batch
29
+ input_ids, labels, media, block_sizes = [], [], {name: [] for name in self.tokenizer.media_tokens}, []
30
+
31
+ media_meta = {}
32
+
33
+ media_meta["sound_feature_masks"] = []
34
+ media_meta["sound_embed_masks"] = []
35
+ media_meta["frame_times"] = []
36
+ for instance in instances:
37
+ if isinstance(instance["input_ids"], torch.Tensor):
38
+ input_ids.append(instance["input_ids"])
39
+ labels.append(instance["labels"])
40
+ for name in media:
41
+ objs = instance.get(name)
42
+ objs = objs if objs is not None else []
43
+ media[name].append([obj for obj in objs])
44
+ if instance.get("sound") is not None:
45
+ for name_k in media_meta:
46
+ if "sound" in name_k:
47
+ objs = instance.get(name_k)
48
+ media_meta[name_k].append([obj for obj in objs])
49
+ if instance.get("video") is not None or instance.get("image") is not None:
50
+ for name_k in media_meta:
51
+ if "frame" in name_k:
52
+ objs = instance.get(name_k)
53
+ media_meta[name_k].append([obj for obj in objs])
54
+ if "block_sizes" in instance:
55
+ block_sizes.append(instance["block_sizes"])
56
+ else:
57
+ block_sizes.append(
58
+ [None for _ in range(len(instance.get("image")))] if instance.get("image") is not None else []
59
+ )
60
+ else:
61
+ input_ids.extend(instance["input_ids"])
62
+ labels.extend(instance["labels"])
63
+ for name in media:
64
+ objs = instance.get(name)
65
+ objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))]
66
+ media[name].extend(objs)
67
+ if instance.get("sound") is not None:
68
+ for name_k in media_meta:
69
+ if "sound" in name_k:
70
+ objs = instance.get(name_k)
71
+ media_meta[name_k].extend(objs)
72
+ if instance.get("video") is not None or instance.get("image") is not None:
73
+ for name_k in media_meta:
74
+ if "frame" in name_k:
75
+ objs = instance.get(name_k)
76
+ media_meta[name_k].append([obj for obj in objs])
77
+ if "block_sizes" in instance:
78
+ block_sizes.extend(instance["block_sizes"])
79
+ else:
80
+ block_sizes.extend(
81
+ [[None for _ in range(len(objs))] for objs in instance.get("image")]
82
+ if instance.get("image") is not None
83
+ else [[] for _ in range(len(instance["input_ids"]))]
84
+ )
85
+
86
+ batch_size = len(input_ids)
87
+
88
+
89
+ # Check if the number of media objects (or the number of block sizes) matches the number of media tokens
90
+ for name in media:
91
+ for k in range(batch_size):
92
+ if name == "image" and not all([_ is None for _ in block_sizes[k]]):
93
+ actual = len(block_sizes[k])
94
+ else:
95
+ actual = len(media[name][k])
96
+ expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
97
+ if actual != expected:
98
+ raise ValueError(
99
+ f"Number mismatch between {name} objects and {name} tokens. "
100
+ f"There are {expected} {name} tokens but {actual} {name} objects."
101
+ )
102
+
103
+ # Batchify the inputs
104
+ input_ids = torch.nn.utils.rnn.pad_sequence(
105
+ input_ids,
106
+ batch_first=True,
107
+ padding_value=self.tokenizer.pad_token_id,
108
+ )
109
+ labels = torch.nn.utils.rnn.pad_sequence(
110
+ labels,
111
+ batch_first=True,
112
+ padding_value=IGNORE_INDEX,
113
+ )
114
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
115
+ labels = labels[:, : self.tokenizer.model_max_length]
116
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
117
+
118
+ # Truncate media objects if necessary
119
+ for name in media:
120
+ objects = []
121
+ for k in range(batch_size):
122
+ if name == "image" and not all([_ is None for _ in block_sizes[k]]):
123
+ actual = len(media[name][k])
124
+ num_large_scale_blocks = sum([x * y for x, y in block_sizes[k]])
125
+ num_small_scale_blocks = actual - num_large_scale_blocks
126
+ num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k])
127
+ expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
128
+ expected = (
129
+ sum([x * y for x, y in block_sizes[k][:expected_full_image]])
130
+ + num_small_scale_blocks_each_img * expected_full_image
131
+ )
132
+ if actual > expected:
133
+ logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
134
+ media[name][k] = media[name][k][:expected]
135
+ objects.extend(media[name][k])
136
+ block_sizes[k] = block_sizes[k][:expected_full_image]
137
+ else:
138
+ actual = len(media[name][k])
139
+ expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
140
+ if actual > expected:
141
+ logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
142
+ media[name][k] = media[name][k][:expected]
143
+ objects.extend(media[name][k])
144
+ if name == "image":
145
+ block_sizes[k] = block_sizes[k][:expected]
146
+ media[name] = objects
147
+
148
+ for name in media_meta:
149
+ objects = []
150
+ for k in range(batch_size):
151
+ try:
152
+ objects.extend(media_meta[name][k])
153
+ except:
154
+ continue
155
+ media_meta[name] = objects
156
+
157
+ # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...] to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...]
158
+ block_sizes = sum(block_sizes, [])
159
+ return {
160
+ "input_ids": input_ids,
161
+ "media": media,
162
+ "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}},
163
+ "labels": labels,
164
+ "attention_mask": attention_mask,
165
+ "media_meta": media_meta,
166
+ }
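A minimal sketch of what `DataCollator` expects (editor's illustration, not a test from this repo): the tokenizer must expose `media_tokens`, `media_token_ids`, `pad_token_id`, and `model_max_length`, and each instance must carry one media object per media token in its `input_ids`.
```python
# Illustrative sketch; the dummy tokenizer stands in for the real one used in training.
import torch
from types import SimpleNamespace
from llava.data.collate import DataCollator

tok = SimpleNamespace(
    media_tokens={"sound": "<sound>", "speech": "<speech>"},
    media_token_ids={"sound": 151651, "speech": 151652},  # ids as listed in llava/constants.py
    pad_token_id=151647,
    model_max_length=4096,
)
collator = DataCollator(tok)

instance = {
    "input_ids": torch.tensor([151651, 100, 101]),  # one <sound> token, then text tokens
    "labels": torch.tensor([-100, 100, 101]),
    "sound": [torch.zeros(1, 128, 3000)],           # one mel-spectrogram window per <sound> token
    "sound_feature_masks": [torch.ones(1, 3000)],
    "sound_embed_masks": [torch.ones(750)],
}
batch = collator([instance])
print(batch["input_ids"].shape, batch["attention_mask"].shape)  # torch.Size([1, 3]) each
```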
llava/data/dataset.py ADDED
@@ -0,0 +1,1635 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
8
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ import base64
23
+ import copy
24
+ import io
25
+ import json
26
+ import os
27
+ import os.path as osp
28
+ import random
29
+ import time
30
+ import warnings
31
+ from dataclasses import dataclass
32
+ from typing import Dict, Sequence
33
+ import math
34
+ import numpy as np
35
+ import PIL
36
+ import torch
37
+ import transformers
38
+ from PIL import Image, ImageFile
39
+ from torch.utils.data import Dataset, default_collate
40
+ from transformers import PreTrainedTokenizer
41
+ from transformers import AutoFeatureExtractor
42
+ import kaldiio
43
+ import llava.data.datasets_mixture as datasets_mixture
44
+ from llava import conversation as conversation_lib
45
+ from llava.constants import DEFAULT_SOUND_TOKEN,DEFAULT_SPEECH_TOKEN, IGNORE_INDEX
46
+ from llava.data.collate import DataCollator
47
+ from llava.mm_utils import (
48
+ load_audio,
49
+ get_num_windows,
50
+ tokenizer_image_token,
51
+ )
52
+ from torchvision import transforms
53
+ from llava.train.args import DataArguments, TrainingArguments
54
+ from llava.train.sequence_parallel import (
55
+ extract_local_from_list,
56
+ extract_local_input_ids,
57
+ extract_local_position_ids,
58
+ get_pg_manager,
59
+ )
60
+ from llava.utils.tokenizer import preprocess_conversation
61
+ # import torchaudio
62
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
63
+ import soundfile as sf
64
+ from librosa import resample as librosa_resample
65
+ import whisper
66
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
67
+ PIL.Image.MAX_IMAGE_PIXELS = 1000000000
68
+
69
+ def int16_to_float32(x):
70
+ return (x / 32767.0).astype(np.float32)
71
+
72
+
73
+ def float32_to_int16(x):
74
+ x = np.clip(x, a_min=-1., a_max=1.)
75
+ return (x * 32767.).astype(np.int16)
76
+
77
+
78
+
79
+ def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
80
+ is_multimodal = data_args.is_multimodal
81
+ if not is_multimodal:
82
+ return sources
83
+
84
+ for source in sources:
85
+ concat_values = "".join([sentence["value"] for sentence in source])
86
+ for sid, sentence in enumerate(source):
87
+ # In multimodal conversations, we automatically prepend '<image>' at the start of the first sentence if it doesn't already contain one.
88
+
89
+ if DEFAULT_SOUND_TOKEN in sentence["value"]:
90
+ sentence["value"] = sentence["value"].replace(DEFAULT_SOUND_TOKEN, f"{DEFAULT_SOUND_TOKEN}\n")
91
+ sentence["value"] = sentence["value"].replace(f"{DEFAULT_SOUND_TOKEN}\n\n", f"{DEFAULT_SOUND_TOKEN}\n")
92
+ if DEFAULT_SPEECH_TOKEN in sentence["value"]:
93
+ sentence["value"] = sentence["value"].replace(DEFAULT_SPEECH_TOKEN, f"{DEFAULT_SPEECH_TOKEN}\n")
94
+ sentence["value"] = sentence["value"].replace(f"{DEFAULT_SPEECH_TOKEN}\n\n", f"{DEFAULT_SPEECH_TOKEN}\n")
95
+ return sources
96
+
97
+
98
+ def preprocess_plain(
99
+ sources: Sequence[str],
100
+ tokenizer: transformers.PreTrainedTokenizer,
101
+ ) -> Dict:
102
+ # add end signal and concatenate together
103
+ conversations = []
104
+ for source in sources:
105
+ assert len(source) == 2
106
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
107
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
108
+ conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
109
+ conversations.append(conversation)
110
+ # tokenize conversations
111
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
112
+ targets = copy.deepcopy(input_ids)
113
+ for target, source in zip(targets, sources):
114
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
115
+ target[:tokenized_len] = IGNORE_INDEX
116
+
117
+ return dict(input_ids=input_ids, labels=targets)
118
+
119
+
120
+ def preprocess(
121
+ sources: Sequence[str],
122
+ tokenizer: transformers.PreTrainedTokenizer,
123
+ has_image: bool = False,
124
+ no_system_prompt: bool = False,
125
+ ) -> Dict:
126
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
127
+ return preprocess_plain(sources, tokenizer)
128
+ return default_collate(
129
+ [
130
+ preprocess_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
131
+ for conversation in sources
132
+ ]
133
+ )
134
+
135
+
136
+ class LazySupervisedDataset(Dataset):
137
+ """Dataset for supervised fine-tuning.
138
+ This class is originally implemented by the LLaVA team and modified by
139
+ Ji Lin and Haotian Tang.
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ data_path: str,
145
+ image_folder: str,
146
+ tokenizer: transformers.PreTrainedTokenizer,
147
+ data_args: DataArguments,
148
+ training_args: TrainingArguments,
149
+ ):
150
+ super().__init__()
151
+ try:
152
+ with open(data_path) as fp:
153
+ list_data_dict = json.load(fp)
154
+ except:
155
+ with open(data_path) as fp:
156
+ list_data_dict = [json.loads(q) for q in fp]
157
+
158
+ # rank0_print("Formatting inputs...Skip in lazy mode")
159
+ print("Formatting inputs...Skip in lazy mode")
160
+ self.tokenizer = tokenizer
161
+ self.list_data_dict = list_data_dict
162
+ self.data_args = data_args
163
+ self.image_folder = image_folder
164
+ self.wav_processor = AutoFeatureExtractor.from_pretrained('/lustre/fsw/portfolios/adlr/users/sreyang/flamingo_v2/NV-Whisper')
165
+
166
+ def __len__(self):
167
+ return len(self.list_data_dict)
168
+
169
+ @property
170
+ def lengths(self):
171
+ length_list = []
172
+ for sample in self.list_data_dict:
173
+ img_tokens = 128 if "image" in sample else 0
174
+ length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
175
+ return length_list
176
+
177
+ @property
178
+ def modality_lengths(self):
179
+ length_list = []
180
+ for sample in self.list_data_dict:
181
+ if 'duration' in sample.keys():
182
+ duration = sample["duration"]
183
+ else:
184
+ duration = 10.
185
+ try:
186
+ cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) + int(math.ceil(duration * 25))
187
+ cur_len = cur_len if "sound" in sample else -cur_len
188
+ length_list.append(cur_len)
189
+ except:
190
+ try:
191
+ cur_len = 0 + int(math.ceil(duration * 25))
192
+ cur_len = cur_len if "sound" in sample else -cur_len
193
+ length_list.append(cur_len)
194
+ except:
195
+ cur_len = 0 + int(math.ceil(10. * 25))
196
+ cur_len = cur_len if "sound" in sample else -cur_len
197
+ length_list.append(cur_len)
198
+ return length_list
199
+
200
+ @staticmethod
201
+ def _load_sound(sound_file, wav_processor, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=3, audio_start = 0.0):
202
+ if sound_file is None:
203
+ return None
204
+ window_length = int(window_length * sample_rate)
205
+ window_overlap = int(window_overlap * sample_rate)
206
+ max_num_window = int(max_num_window)
207
+ duration = max_num_window * (window_length - window_overlap) + window_overlap
208
+
209
+ sound_outputs = []
210
+ audio_feature_masks = []
211
+ audio_embed_masks = []
212
+
213
+ try:
214
+ sound_filename = str.split(sound_file, '/')[-1]
215
+ if '.ark' in sound_filename:
216
+ sound = kaldiio.load_mat(sound_file)
217
+ audio_data = sound[1]
218
+ audio_data=audio_data.astype(np.float16)
219
+ else:
220
+ audio_data = load_audio(sound_file, sample_rate, duration, audio_start) # already cuts to max duration
221
+ T = len(audio_data)
222
+ audio_data = audio_data.reshape(1, -1)
223
+ num_windows, full_length = get_num_windows(T, sample_rate, max_num_window)
224
+
225
+ audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
226
+
227
+ for i in range(num_windows):
228
+ audio_embed_mask = torch.zeros(750)
229
+ start = i * (window_length - window_overlap)
230
+ audio_data_tensor_this = audio_data_tensor[:, start:start+window_length]
231
+ orig_length = audio_data_tensor_this.shape[1]
232
+ audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt")
233
+ sound_outputs.append(audio_data_tensor_this["input_features"])
234
+ # calculate the mask for the input melspec to Whisper
235
+ melspec_frames_this_window = int(math.ceil(orig_length / 160))
236
+ feature_attention_mask = torch.zeros(3000, dtype=torch.int32)
237
+ feature_attention_mask[:melspec_frames_this_window] = 1
238
+ audio_feature_masks.append(feature_attention_mask.unsqueeze(0))
239
+ # calculate the mask for the output embedding for use in AF2
240
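+ # the stride-2 conv in the Whisper encoder halves the mel frame count, and a further 2x pooling yields at most 750 output embeddings per 30 s window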
+ conv_lengths = (melspec_frames_this_window - 1) // 2 + 1
241
+ output_embedding_lengths = (conv_lengths - 2) // 2 + 1
242
+ audio_embed_mask[:output_embedding_lengths] = 1
243
+ audio_embed_masks.append(audio_embed_mask)
244
+ except Exception as e:
245
+ print('error loading file', sound_file, e)
246
+ sound_outputs.append(torch.zeros(1,128,3000))
247
+ audio_feature_masks.append(torch.zeros(1,3000, dtype=torch.int32))
248
+ audio_embed_masks.append(torch.zeros(750))
249
+
250
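+ # stacked over windows: (num_windows, 1, 128, 3000), (num_windows, 1, 3000), (num_windows, 750)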
+ return torch.stack(sound_outputs, dim=0), torch.stack(audio_feature_masks, dim=0), torch.stack(audio_embed_masks, dim=0)
251
+
252
+ @staticmethod
253
+ def _load_speech(speech_path,sample_rate=16000):
254
+ if speech_path is None:
255
+ return None
256
+
257
+ speech_outputs = []
258
+ try:
259
+ speech = whisper.load_audio(speech_path)
260
+ speech = whisper.pad_or_trim(speech)
261
+ mel = whisper.log_mel_spectrogram(speech)
262
+ speech_outputs.append(mel.unsqueeze(0))
263
+ except Exception:
264
+ speech_outputs.append(torch.zeros(1,80,3000))
265
+ return torch.stack(speech_outputs, dim=0)
266
+
267
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
268
+ sources = self.list_data_dict[i]
269
+ if isinstance(i, int):
270
+ sources = [sources]
271
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
272
+
273
+ import re
274
+ if "sound" in self.list_data_dict[i]:
275
+ # chat data loading
276
+ if isinstance(self.list_data_dict[i]["sound"],list):
277
+ sound_files = self.list_data_dict[i]["sound"]
278
+ conversations_raw = self.list_data_dict[i]["conversations"]
279
+
280
+ # Step 1: Extract <sound-X> tags in order of appearance
281
+ sound_tag_pattern = re.compile(r"<sound-(\d+)>")
282
+ ordered_sound_tags = []
283
+
284
+ for turn in conversations_raw:
285
+ tags = sound_tag_pattern.findall(turn["value"])
286
+ ordered_sound_tags.extend([f"<sound-{tag}>" for tag in tags])
287
+
288
+ # Step 2: Load sound tensors in the order of tags
289
+ sound_tensor = []
290
+ audio_feature_masks = []
291
+ audio_embed_masks = []
292
+ sound_token_map = {}
293
+
294
+ for tag in ordered_sound_tags:
295
+ idx = int(tag.split('-')[1][:-1])
296
+ if tag not in sound_token_map:
297
+ this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_files[idx], self.wav_processor, max_num_window=self.data_args.audio_frames)
298
+ this_sound_tensor = this_sound_tensor.squeeze(1)  # (num_windows, 128, 3000) mel features
299
+ sound_token_map[tag] = ("<sound>\n" * this_sound_tensor.shape[0]).rstrip()
300
+ sound_tensor.append(this_sound_tensor)
301
+ audio_feature_masks.append(af_mask)
302
+ audio_embed_masks.append(ae_mask)
303
+ else:
304
+ # If already loaded, still append to match sequence
305
+ this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_files[idx], self.wav_processor, max_num_window=self.data_args.audio_frames)
306
+ this_sound_tensor = this_sound_tensor.squeeze(1)
307
+ sound_tensor.append(this_sound_tensor)
308
+ audio_feature_masks.append(af_mask)
309
+ audio_embed_masks.append(ae_mask)
310
+
311
+
312
+ # Process conversations and inject sound markers
313
+ conversation = []
314
+ for turn in conversations_raw:
315
+ role = turn["from"]
316
+ value = turn["value"]
317
+
318
+ # Replace any <sound-X> tag with corresponding repeated <sound>\n
319
+ for tag, sound_token in sound_token_map.items():
320
+ value = value.replace(tag, sound_token)
321
+
322
+ conversation.append({
323
+ "from": role,
324
+ "value": value.rstrip()
325
+ })
326
+
327
+ sources = [conversation]
328
+ sound_tensor = torch.cat(sound_tensor, dim=0)
329
+ audio_feature_masks = torch.cat(audio_feature_masks, dim=0)
330
+ audio_embed_masks = torch.cat(audio_embed_masks, dim=0)
331
+ else:
332
+ sound_file = self.list_data_dict[i]["sound"]
333
+ question = str(self.list_data_dict[i]["conversations"][0]["value"].rstrip())
334
+ answer = str(self.list_data_dict[i]["conversations"][1]["value"]).rstrip()
335
+ question = question.replace("<speech>\n", "").replace("\n<speech>", "").replace("<speech>", "")
336
+ question = question.replace("<sound>\n", "").replace("\n<sound>", "").replace("<sound>", "")
337
+ question = question.replace("<en><asr>\n", "").replace("\n<en><asr>", "").replace("<en><asr>", "")
338
+ question = question.replace("<eng><asr>\n", "").replace("\n<eng><asr>", "").replace("<eng><asr>", "")
339
+ sound_tensor, audio_feature_masks, audio_embed_masks = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames)
340
+ sound_tensor = sound_tensor.squeeze(1)  # drop the singleton batch dimension added by the processor --> (num_windows, 128, 3000) mel features
341
+ question = "<sound>\n" * sound_tensor.shape[0] + question
342
+ conversation = [
343
+ {"from": "human", "value": question},
344
+ {"from": "gpt", "value": answer},
345
+ ]
346
+
347
+ sources = [conversation]
348
+ data_dict = preprocess(
349
+ sources,
350
+ self.tokenizer,
351
+ has_image=(
352
+ "sound" in self.list_data_dict[i]
353
+ ),
354
+ )
355
+ if isinstance(i, int):
356
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
357
+
358
+ if "sound" in self.list_data_dict[i]:
359
+ data_dict["sound"] = sound_tensor
360
+ data_dict["sound_feature_masks"] = audio_feature_masks
361
+ data_dict["sound_embed_masks"] = audio_embed_masks
362
+ if "speech" in self.list_data_dict[i]:
363
+ data_dict["speech"] = self._load_speech(self.list_data_dict[i]["speech"])  # assumes "speech" holds an audio file path
364
+
365
+ return data_dict
366
+
367
+
368
+ class LazyMMC4Dataset(Dataset):
369
+ """Dataset for supervised fine-tuning.
370
+ This class is implemented by Ji Lin and Haotian Tang."""
+
+ # rough per-image token count used only by the modality_lengths estimate (mirrors LazyCoyoDataset)
+ num_image_tokens = 576
371
+
372
+ def __init__(
373
+ self,
374
+ data_path: str,
375
+ image_folder: str,
376
+ tokenizer: transformers.PreTrainedTokenizer,
377
+ data_args: DataArguments,
378
+ training_args: TrainingArguments,
379
+ image_following_text_only=False,
380
+ text_only=False,
381
+ ):
382
+ super().__init__()
383
+
384
+ import pickle
385
+
386
+ n_samples = []
387
+ # actually shards and stats info
388
+ n_shards = len(os.listdir(data_path)) // 2
389
+ # n_shards = 100
390
+ count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards]
391
+ n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list]
392
+
393
+ print("total MMC4 samples", sum(n_samples)) # 10,881,869
394
+
395
+ PROCESS_GROUP_MANAGER = get_pg_manager()
396
+ if PROCESS_GROUP_MANAGER is not None:
397
+ import torch.distributed as dist
398
+
399
+ sequence_parallel_size = training_args.seq_parallel_size
400
+ else:
401
+ sequence_parallel_size = 1
402
+ print("sequence_parallel_size", sequence_parallel_size)
403
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
404
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
405
+ shared_size = n_shards // world_size
406
+
407
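+ # each rank takes a contiguous block of shards; truncating to the smallest per-rank count keeps the epoch length identical across ranks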
+ gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)]
408
+ self.n_samples = min(gpu_samples) * world_size # total size
409
+ self.idx_offset = rank * min(gpu_samples)
410
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
411
+ print(f" * loading data from shard {shard_start}-{shard_end}")
412
+
413
+ shard_names = [d.replace(".count", ".pkl") for d in count_info_list]
414
+ shard_names = shard_names[shard_start:shard_end]
415
+
416
+ full_data_list = []
417
+ # now load data
418
+ for shard_name in shard_names:
419
+ # load shard
420
+ with open(os.path.join(data_path, shard_name), "rb") as f:
421
+ data_list = pickle.load(f)
422
+
423
+ full_data_list.extend(data_list)
424
+
425
+ print(f"* loaded a total of {len(full_data_list)} samples")
426
+
427
+ self.data_list = full_data_list
428
+
429
+ self.tokenizer = tokenizer
430
+ self.data_args = data_args
431
+ self.image_folder = image_folder
432
+
433
+ self.image_following_text_only = image_following_text_only
434
+ self.text_only = text_only
435
+
436
+ def __len__(self):
437
+ # return len(self.data_list)
438
+ return self.n_samples
439
+
440
+ @property
441
+ def modality_lengths(self):
442
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
443
+ length_list = []
444
+ for info in self.data_list:
445
+ num_images = min(6, len(info["image_info"]))
446
+ sentences = [info["text_list"][x["matched_text_index"]] for x in info["image_info"][:num_images]]
447
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
448
+ cur_len = num_images * self.num_image_tokens // 2 + sum([len(x) for x in sentences])
449
+ length_list.append(cur_len)
450
+ return length_list
451
+
452
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
453
+ info = self.data_list[i - self.idx_offset]
454
+
455
+ sentences = info["text_list"]
456
+ # kentang-mit@: remove existing <image> tokens in the sentences
457
+ for ix in range(len(sentences)):
458
+ # if this is an html tag, we still preserve its semantic meaning
459
+ sentences[ix] = sentences[ix].replace("<image>", "<IMAGE>")
460
+ sim_matrix = info["similarity_matrix"] # we do not use this...
461
+
462
+ # convert images from base64 to PIL and filter based on image-text similarity
463
+ images, sentence_ixs = [], []
464
+ if not self.text_only:
465
+ for sample_image, sim_vec in zip(info["image_info"], sim_matrix):
466
+ image_base64 = sample_image["image_base64"]
467
+ rawbytes = base64.b64decode(image_base64)
468
+
469
+ sim_ix = sample_image["matched_text_index"]
470
+ # sim_ix = np.argmax(sim_vec)
471
+ # sim_score = sim_vec[sim_ix]
472
+
473
+ # filter to images >= 5KB
474
+ # if len(rawbytes) // 1000 <= 5:
475
+ # continue
476
+ # if sim_score < 0.24:
477
+ # continue
478
+ image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
479
+
480
+ images.append(image)
481
+ sentence_ixs.append(sim_ix)
482
+
483
+ # constrain max num 6 images
484
+ max_num_images = 6
485
+ if len(images) > max_num_images:
486
+ images = images[:max_num_images]
487
+ sentence_ixs = sentence_ixs[:max_num_images]
488
+
489
+ # reorder images according to text insertion
490
+ images = [images[iii] for iii in np.argsort(sentence_ixs)]
491
+
492
+ # preprocess and tokenize text
493
+ for ix in sentence_ixs:
494
+ sentences[ix] = f"<image>\n{sentences[ix]}"
495
+
496
+ if self.image_following_text_only:
497
+ # use pad tokens to divide sentence pieces
498
+ text = self.tokenizer.pad_token.join(sentences)
499
+ else:
500
+ text = " ".join(sentences)
501
+ # whitespace cleanup
502
+ text = text.replace("<image> ", "<image>").replace(" <image>", "<image>")
503
+ text = f"{text}{self.tokenizer.eos_token}" # add eos token
504
+
505
+ if len(images) > 0:
506
+ if self.data_args.image_aspect_ratio == "dynamic_s2":
507
+ images, block_sizes = dynamic_s2_process_images_and_prompt(
508
+ images, text, self.data_args, self.image_folder
509
+ )
510
+ elif self.data_args.image_aspect_ratio == "dynamic":
511
+ images, text = dynamic_process_images_and_prompt(
512
+ images, text, self.data_args, self.image_folder, max_tiles=6
513
+ )
514
+ else:
515
+ images = torch.stack([process_image(image, self.data_args, self.image_folder) for image in images])
516
+
517
+ # the same size for all images, so we concat
518
+ # cur_token_len = (
519
+ # images[0].shape[-2] // self.multimodal_cfg["patch_size"]
520
+ # ) * (images[0].shape[-1] // self.multimodal_cfg["patch_size"])
521
+ # cur_token_len += self.multimodal_cfg["n_extra_patch"]
522
+ else:
523
+ images = None
524
+ # cur_token_len = 0
525
+
526
+ input_ids = tokenizer_image_token(
527
+ text,
528
+ self.tokenizer,
529
+ return_tensors="pt",
530
+ )
531
+
532
+ image_token_id = self.tokenizer.media_token_ids["image"]
533
+
534
+ # now check the case where the last token is image patch token
535
+ if input_ids[-1] == image_token_id: # need to remove one last image
536
+ last_non_im_patch_indices = torch.where(input_ids != image_token_id)[0][-1] + 1
537
+ input_ids = input_ids[:last_non_im_patch_indices]
538
+
539
+ n_im_patch = (input_ids == image_token_id).sum().item()
540
+
541
+ if self.data_args.image_aspect_ratio != "dynamic_s2":
542
+ images = images[:n_im_patch]
543
+ assert len(images) == n_im_patch, print(text, input_ids)
544
+ assert len(input_ids.shape) == 1, "Unexpected shape of 'input_ids' from MMC4."
545
+ input_ids = (
546
+ torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids])
547
+ if self.tokenizer.bos_token_id is not None and input_ids[0] != self.tokenizer.bos_token_id
548
+ else input_ids
549
+ )
550
+ targets = input_ids.clone()
551
+
552
+ if self.image_following_text_only: # keep only text after leading image token
553
+ # remove loss for any token before the first <image> token
554
+ label_idx = 0
555
+ while label_idx < targets.shape[-1] and targets[label_idx] != image_token_id:
556
+ targets[label_idx] = IGNORE_INDEX
557
+ label_idx += 1
558
+
559
+ pad_token = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0]
560
+
561
+ pad_token_idxs = torch.where(targets == pad_token)[0]
562
+ for pad_token_idx in pad_token_idxs:
563
+ token_idx = pad_token_idx + 1
564
+ while token_idx < targets.shape[-1] and targets[token_idx] != image_token_id:
565
+ targets[token_idx] = IGNORE_INDEX
566
+ token_idx += 1
567
+ # do not train on padding tokens
568
+ targets[targets == pad_token] = IGNORE_INDEX
569
+
570
+ # mask image tokens is unnecessary for llava-1.5
571
+ # targets[targets == IMAGE_TOKEN_INDEX] = IGNORE_INDEX
572
+ # print(input_ids.shape)
573
+
574
+ data_dict = dict(input_ids=input_ids, labels=targets, image=images)
575
+ if self.data_args.image_aspect_ratio == "dynamic_s2":
576
+ data_dict["block_sizes"] = block_sizes
577
+
578
+ return data_dict
579
+
580
+
581
+ class LazyCoyoDataset(Dataset):
582
+ """Dataset for supervised fine-tuning.
583
+ This class is implemented by Ji Lin and Haotian Tang."""
584
+
585
+ num_image_tokens = 576
586
+
587
+ def __init__(
588
+ self,
589
+ data_path: str,
590
+ image_folder: str,
591
+ tokenizer: transformers.PreTrainedTokenizer,
592
+ data_args: DataArguments,
593
+ training_args: TrainingArguments,
594
+ # kentang-mit@: balance the total number of tokens for Coyo and MMC4.
595
+ n_samples_per_idx=4,
596
+ ):
597
+ super().__init__()
598
+
599
+ import pickle
600
+
601
+ n_samples = []
602
+ # actually shards and stats info
603
+ n_shards = len(os.listdir(data_path)) // 2
604
+ # n_shards = 100
605
+ count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards]
606
+ n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list]
607
+
608
+ print("total COYO samples", sum(n_samples))
609
+
610
+ PROCESS_GROUP_MANAGER = get_pg_manager()
611
+ if PROCESS_GROUP_MANAGER is not None:
612
+ import torch.distributed as dist
613
+
614
+ sequence_parallel_size = training_args.seq_parallel_size
615
+ else:
616
+ sequence_parallel_size = 1
617
+ print("sequence_parallel_size", sequence_parallel_size)
618
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
619
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
620
+ shared_size = n_shards // world_size
621
+
622
+ gpu_samples = [
623
+ sum(n_samples[i * shared_size : (i + 1) * shared_size]) // n_samples_per_idx for i in range(world_size)
624
+ ]
625
+ self.n_samples = min(gpu_samples) * world_size # total size
626
+ self.idx_offset = rank * min(gpu_samples)
627
+
628
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
629
+ print(f" * loading data from shard {shard_start}-{shard_end}")
630
+
631
+ shard_names = [d.replace(".count", ".pkl") for d in count_info_list]
632
+ shard_names = shard_names[shard_start:shard_end]
633
+
634
+ full_data_list = []
635
+ # now load data
636
+ for shard_name in shard_names:
637
+ # load shard
638
+ with open(os.path.join(data_path, shard_name), "rb") as f:
639
+ shard_data = pickle.load(f)
640
+ random.seed(42)
641
+ if "mmc4" in data_path:
642
+ random.shuffle(shard_data) # shuffle for MMC4cap only
643
+ full_data_list.extend(shard_data)
644
+
645
+ print(f"* loaded a total of {len(full_data_list)} samples")
646
+
647
+ # now pack the samples into groups
648
+ n_groups = len(full_data_list) // n_samples_per_idx
649
+ full_data_list = [
650
+ full_data_list[i : i + n_samples_per_idx] for i in range(0, len(full_data_list), n_samples_per_idx)
651
+ ]
652
+ if len(full_data_list[-1]) < n_samples_per_idx:
653
+ full_data_list = full_data_list[:-1]
654
+ assert len(full_data_list) == n_groups
655
+ print(f"split into {n_groups} groups")
656
+
657
+ self.data_list = full_data_list
658
+
659
+ self.tokenizer = tokenizer
660
+ self.data_args = data_args
661
+ self.image_folder = image_folder
662
+
663
+ def __len__(self):
664
+ # return len(self.data_list)
665
+ return self.n_samples
666
+
667
+ @property
668
+ def modality_lengths(self):
669
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
670
+ length_list = []
671
+ for samples in self.data_list:
672
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
673
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
674
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
675
+ length_list.append(cur_len)
676
+ return length_list
677
+
678
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
679
+ CONCAT_SAMPLES = False
680
+ info_list = self.data_list[i - self.idx_offset]
681
+
682
+ text_list = []
683
+ image_list = []
684
+
685
+ for sample in info_list:
686
+ caption_key = (
687
+ "text" if "text" in sample else "caption"
688
+ ) # kentang-mit@: remove existing <image> tokens in the sentences
689
+ # kentang-mit@: remove existing <image> token.
690
+ # if this is an html tag, we still preserve its semantic meaning
691
+ sample[caption_key] = sample[caption_key].replace("<image>", "<IMAGE>")
692
+ text_list.append(DEFAULT_IMAGE_TOKEN + "\n" + sample[caption_key] + self.tokenizer.eos_token)
693
+ if "image" in sample:
694
+ image_base64 = sample["image"]
695
+ rawbytes = base64.b64decode(image_base64)
696
+ else:
697
+ rawbytes = sample["rawbytes"]
698
+ image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
699
+ image_list.append(image)
700
+
701
+ image_list = torch.stack([process_image(image, self.data_args, self.image_folder) for image in image_list])
702
+
703
+ if CONCAT_SAMPLES:
704
+ # into <image>cap<eos><image>cap<eos>...
705
+ text_list = "".join(text_list)
706
+
707
+ input_ids = self.tokenizer(
708
+ text_list,
709
+ return_tensors="pt",
710
+ padding="longest",
711
+ max_length=self.tokenizer.model_max_length,
712
+ truncation=True,
713
+ ).input_ids # 4, seq_len
714
+
715
+ input_ids = input_ids[0]
716
+
717
+ else:
718
+ input_ids = [
719
+ tokenizer_image_token(
720
+ prompt,
721
+ self.tokenizer,
722
+ return_tensors="pt",
723
+ )
724
+ for prompt in text_list
725
+ ]
726
+ # print([x.shape[0] for x in input_ids], [len(x.split()) for x in text_list], [len(re.findall(r"<image[^>]*>", x)) for x in text_list])
727
+
728
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
729
+ # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
730
+ # )
731
+
732
+ targets = copy.deepcopy(input_ids)
733
+ for i in range(len(targets)):
734
+ targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX
735
+
736
+ return dict(input_ids=input_ids, labels=targets, image=image_list)
737
+
738
+
739
+ class LazyWDSDataset(Dataset):
740
+ """Dataset for supervised fine-tuning.
741
+ This class is implemented by Ji Lin and Ligeng Zhu."""
742
+
743
+ def __init__(
744
+ self,
745
+ data_path: str,
746
+ tokenizer: transformers.PreTrainedTokenizer,
747
+ data_args: DataArguments,
748
+ image_folder: str,
749
+ training_args: TrainingArguments,
750
+ ):
751
+ super().__init__()
752
+ n_samples = []
753
+ n_shards = len(os.listdir(data_path)) // 3
754
+ for shard in range(n_shards):
755
+ with open(os.path.join(data_path, f"{shard:05d}_stats.json")) as f:
756
+ info = json.load(f)
757
+ n_samples.append(info["successes"])
758
+
759
+ # print(f"[DEBUG] {data_path} total samples", sum(n_samples)) # 10,881,869
760
+
761
+ PROCESS_GROUP_MANAGER = get_pg_manager()
762
+ if PROCESS_GROUP_MANAGER is not None:
763
+ import torch.distributed as dist
764
+
765
+ sequence_parallel_size = training_args.seq_parallel_size
766
+ else:
767
+ sequence_parallel_size = 1
768
+ print("sequence_parallel_size", sequence_parallel_size)
769
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
770
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
771
+ shared_size = n_shards // world_size
772
+ print("rank", rank, "world_size", world_size, "shared_size", shared_size)
773
+ gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)]
774
+ self.n_samples = min(gpu_samples) * world_size # total size
775
+ self.idx_offset = rank * min(gpu_samples)
776
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
777
+ print(f" * loading data from shard {shard_start}-{shard_end}")
778
+
779
+ tar_list = [f"{shard_idx:05d}.tar" for shard_idx in range(shard_start, shard_end)]
780
+
781
+ self.data_list = []
782
+ t1 = time.time()
783
+ for tar in tar_list:
784
+ tmp_path = f"/tmp/ccs{tar}"
785
+ tar_path = os.path.join(data_path, tar)
786
+
787
+ if PROCESS_GROUP_MANAGER is not None:
788
+ dist.barrier()
789
+ if PROCESS_GROUP_MANAGER.sp_rank == 0:
790
+ os.makedirs(tmp_path, exist_ok=True)
791
+ os.system(f"tar -xkf {tar_path} -C {tmp_path}")
792
+ dist.barrier()
793
+ else:
794
+ os.makedirs(tmp_path, exist_ok=True)
795
+ os.system(f"tar -xkf {tar_path} -C {tmp_path}")
796
+
797
+ txt_list = [f for f in os.listdir(tmp_path) if f.endswith(".txt")]
798
+
799
+ for txt in txt_list:
800
+ caption = open(os.path.join(tmp_path, txt)).read().strip()
801
+ image_path = os.path.join(tmp_path, txt.split(".")[0] + ".jpg")
802
+ self.data_list.append({"caption": caption, "image": image_path})
803
+ t2 = time.time()
804
+ print(f"Loading done. Total time: {t2 - t1:.2f} seconds")
805
+
806
+ self.tokenizer = tokenizer
807
+ self.data_args = data_args
808
+ self.image_folder = image_folder
809
+
810
+ def __len__(self):
811
+ return self.n_samples
812
+
813
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
814
+
815
+ # print("i", i, "idx_offset", self.idx_offset, "len", len(self.data_list))
816
+ info = self.data_list[i - self.idx_offset]
817
+ caption, image_path = info["caption"], info["image"]
818
+
819
+ rand_prompt = "<image>\n"
820
+ sources = [
821
+ {
822
+ "image": image_path,
823
+ "conversations": [
824
+ {"from": "human", "value": rand_prompt},
825
+ {"from": "gpt", "value": caption},
826
+ ],
827
+ }
828
+ ]
829
+
830
+ # one example of sources
831
+ # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n<image>'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}]
832
+ if "image" in sources[0]:
833
+ image = process_image(sources[0]["image"], self.data_args, self.image_folder)
834
+ image = torch.unsqueeze(image, dim=0)
835
+ # now random pick some context samples for training
836
+ if hasattr(self.data_args, "num_shots"):
837
+ if self.data_args.num_shots > 0:
838
+ raise NotImplementedError
839
+ else:
840
+ raise NotImplementedError
841
+
842
+ data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True)
843
+
844
+ if isinstance(i, int):
845
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
846
+
847
+ # image exist in the data
848
+ if image is not None:
849
+ data_dict["image"] = image
850
+ else:
851
+ raise NotImplementedError
852
+
853
+ return data_dict
854
+
855
+
856
+ class LazyCCSWebDataset(Dataset):
857
+ """Dataset for supervised fine-tuning.
858
+ This class is implemented by Ligeng Zhu."""
859
+
860
+ def __init__(
861
+ self,
862
+ data_path: str,
863
+ image_folder: str,
864
+ tokenizer: transformers.PreTrainedTokenizer,
865
+ data_args: DataArguments,
866
+ training_args: TrainingArguments,
867
+ ):
868
+ super().__init__()
869
+ t1 = time.time()
870
+
871
+ from llava.data.simple_vila_webdataset import VILAWebDataset
872
+
873
+ print("[DEBUG] ", osp.abspath(data_path))
874
+ self.dataset = VILAWebDataset(data_path=osp.abspath(data_path))
875
+
876
+ t2 = time.time()
877
+ print(f"Loading done. Total time: {t2 - t1:.2f} seconds")
878
+
879
+ self.tokenizer = tokenizer
880
+ self.data_args = data_args
881
+
882
+ def __len__(self):
883
+ return len(self.dataset)
884
+
885
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
886
+ # info = self.data_list[i - self.idx_offset]
887
+ # caption, image_path = info["caption"], info["image"]
888
+ info = self.dataset[i]
889
+ if ".jpg" in info:
890
+ caption, image_path = info[".txt"], info[".jpg"]
891
+ elif ".png" in info:
892
+ caption, image_path = info[".txt"], info[".png"]
893
+ elif ".webp" in info:
894
+ caption, image_path = info[".txt"], info[".webp"]
895
+ elif ".bmp" in info:
896
+ caption, image_path = info[".txt"], info[".bmp"]
897
+ elif ".tiff" in info:
898
+ caption, image_path = info[".txt"], info[".tiff"]
899
+ else:
900
+ print(info.keys())
901
+ print(info)
902
+ raise KeyError
903
+
904
+ caption = caption.replace("<image>", "<IMAGE>")
905
+ if isinstance(image_path, io.BytesIO):
906
+ image_path = Image.open(image_path).convert("RGB")
907
+
908
+ if not isinstance(image_path, PIL.Image.Image):
909
+ print(image_path)
910
+ print(info.keys())
911
+ print(type(image_path))
912
+ raise NotImplementedError
913
+
914
+ rand_prompt = "<image>\n"
915
+ sources = [
916
+ {
917
+ "image": image_path,
918
+ "conversations": [
919
+ {"from": "human", "value": rand_prompt},
920
+ {"from": "gpt", "value": caption},
921
+ ],
922
+ }
923
+ ]
924
+
925
+ # one example of sources
926
+ # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n<image>'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}]
927
+ if "image" in sources[0]:
928
+ image = process_image(sources[0]["image"], self.data_args, image_folder=None)
929
+ image = torch.unsqueeze(image, dim=0)
930
+ # now random pick some context samples for training
931
+ if hasattr(self.data_args, "num_shots"):
932
+ if self.data_args.num_shots > 0:
933
+ raise NotImplementedError
934
+ else:
935
+ raise NotImplementedError
936
+
937
+ data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True)
938
+
939
+ if isinstance(i, int):
940
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
941
+
942
+ # image exist in the data
943
+ if image is not None:
944
+ data_dict["image"] = image
945
+ else:
946
+ raise NotImplementedError
947
+
948
+ return data_dict
949
+
950
+
951
+ from functools import lru_cache
952
+
953
+
954
+ @lru_cache(maxsize=16)
955
+ def lru_json_load(fpath):
956
+ with open(fpath) as fp:
957
+ return json.load(fp)
958
+
959
+
960
+ class LazyCoyoWebDataset(Dataset):
961
+ """Dataset for supervised fine-tuning.
962
+ This class is implemented by Ligeng Zhu."""
963
+
964
+ num_image_tokens = 576
965
+
966
+ def __init__(
967
+ self,
968
+ data_path: str,
969
+ image_folder: str,
970
+ tokenizer: transformers.PreTrainedTokenizer,
971
+ data_args: DataArguments,
972
+ training_args: TrainingArguments,
973
+ # kentang-mit@: balance the total number of tokens for Coyo and MMC4.
974
+ n_samples_per_idx=4,
975
+ ):
976
+ super().__init__()
977
+
978
+ from llava.data.simple_vila_webdataset import VILAWebDataset
979
+
980
+ print("[DEBUG] ", osp.abspath(data_path))
981
+ self.dataset = VILAWebDataset(data_path=osp.abspath(data_path), meta_path=data_args.meta_path)
982
+
983
+ if data_args.start_idx >= 0 and data_args.end_idx >= 0:
984
+ # Ligeng: support slicing for ablate different subsets.
985
+ total = len(self.dataset)
986
+ start_idx = int(total * data_args.start_idx)
987
+ end_idx = int(total * data_args.end_idx)
988
+ print(f"loading subset from {start_idx} to {end_idx}, total {total}")
989
+ self.dataset = torch.utils.data.Subset(self.dataset, range(start_idx, end_idx))
990
+
991
+ # For caption choice,
992
+ # if None: use original caption
993
+ # if a folder path: use specified caption to override original one (choice1)
994
+ # if a folder path: use specified caption and concat with original one (choice2)
995
+ self.caption_choice = None
996
+ self.caption_choice_2 = None
997
+ self.data_path = data_path
998
+
999
+ if data_args.caption_choice is not None:
1000
+ self.caption_choice = data_args.caption_choice
1001
+ print("[recap] Override coyo caption using ", self.caption_choice)
1002
+
1003
+ if data_args.caption_choice_2 is not None:
1004
+ self.caption_choice_2 = data_args.caption_choice_2
1005
+ print("[recapv2] Override coyo caption using ", self.caption_choice_2)
1006
+
1007
+ print("total samples", len(self.dataset))
1008
+ PROCESS_GROUP_MANAGER = get_pg_manager()
1009
+ if PROCESS_GROUP_MANAGER is not None:
1010
+ import torch.distributed as dist
1011
+
1012
+ sequence_parallel_size = training_args.seq_parallel_size
1013
+ sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank
1014
+ else:
1015
+ sequence_parallel_size = 1
1016
+ print("sequence_parallel_size", sequence_parallel_size)
1017
+ rank = (
1018
+ training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2
1019
+ ) # int(os.environ["RANK"])
1020
+ world_size = (
1021
+ training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32
1022
+ ) # int(os.environ["WORLD_SIZE"])
1023
+ print(
1024
+ "rank",
1025
+ rank,
1026
+ "world_size",
1027
+ world_size,
1028
+ )
1029
+
1030
+ self.n_samples_per_idx = n_samples_per_idx
1031
+ # self.n_samples = len(self.dataset) // n_samples_per_idx
1032
+ self.tokenizer = tokenizer
1033
+ self.data_args = data_args
1034
+
1035
+ def __len__(self):
1036
+ return len(self.dataset) // self.n_samples_per_idx
1037
+
1038
+ @property
1039
+ def modality_lengths(self):
1040
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
1041
+ length_list = []
1042
+ for samples in self.data_list:
1043
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
1044
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
1045
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
1046
+ length_list.append(cur_len)
1047
+ return length_list
1048
+
1049
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
1050
+ CONCAT_SAMPLES = False
1051
+ # info_list = self.dataset[i - self.idx_offset]
1052
+
1053
+ begin_idx, end_idx = (
1054
+ i * self.n_samples_per_idx,
1055
+ (i + 1) * self.n_samples_per_idx,
1056
+ )
1057
+ end_idx = min(end_idx, len(self.dataset))
1058
+
1059
+ text_list = []
1060
+ image_list = []
1061
+
1062
+ for idx in range(begin_idx, end_idx):
1063
+ info = self.dataset[idx]
1064
+ if ".jpg" in info:
1065
+ caption, image_path = info[".txt"], info[".jpg"]
1066
+ elif ".png" in info:
1067
+ caption, image_path = info[".txt"], info[".png"]
1068
+ elif ".webp" in info:
1069
+ caption, image_path = info[".txt"], info[".webp"]
1070
+ elif ".bmp" in info:
1071
+ caption, image_path = info[".txt"], info[".bmp"]
1072
+ elif ".tiff" in info:
1073
+ caption, image_path = info[".txt"], info[".tiff"]
1074
+ else:
1075
+ print(info.keys())
1076
+ print(info)
1077
+ raise KeyError
1078
+
1079
+ if self.caption_choice is not None:
1080
+ # load new captions
1081
+ shard = info["__shard__"]
1082
+ url = info[".json"]["url"]
1083
+ tar_name = osp.relpath(osp.realpath(shard), osp.realpath(self.data_path))
1084
+ # tar_name = osp.dirname(shard)
1085
+ shard_json_path = osp.join(self.caption_choice, tar_name + ".json")
1086
+ try:
1087
+ shard_json = lru_json_load(shard_json_path)
1088
+ try:
1089
+ caption = shard_json[url]["output"]
1090
+ except KeyError:
1091
+ print(f"{url} not found in override captions; falling back to the original caption temporarily")
1092
+ except:
1093
+ print(f"shard_json_path {shard_json_path} not found; falling back to the original caption temporarily")
1094
+ caption = caption.replace("<image>", "<IMAGE>")
1095
+ text_list.append(DEFAULT_IMAGE_TOKEN + caption + self.tokenizer.eos_token)
1096
+
1097
+ if isinstance(image_path, io.BytesIO):
1098
+ image_path = Image.open(image_path).convert("RGB")
1099
+
1100
+ if not isinstance(image_path, PIL.Image.Image):
1101
+ print(image_path)
1102
+ print(info.keys())
1103
+ print(type(image_path))
1104
+ raise NotImplementedError
1105
+
1106
+ image_list.append(image_path)
1107
+
1108
+ # image_list = torch.stack([process_image(image, self.data_args, image_folder=None) for image in image_list])
1109
+ # NOTE(fix by ligeng)
1110
+ # now image_list is a list of image tensors, each with shape (1, c, h, w)
1111
+ image_list = [process_image(image, self.data_args, image_folder=None).unsqueeze(0) for image in image_list]
1112
+
1113
+ if CONCAT_SAMPLES:
1114
+ # into <image>cap<eos><image>cap<eos>...
1115
+ text_list = "".join(text_list)
1116
+
1117
+ input_ids = self.tokenizer(
1118
+ text_list,
1119
+ return_tensors="pt",
1120
+ padding="longest",
1121
+ max_length=self.tokenizer.model_max_length,
1122
+ truncation=True,
1123
+ ).input_ids # 4, seq_len
1124
+
1125
+ input_ids = input_ids[0]
1126
+ else:
1127
+ input_ids = [
1128
+ tokenizer_image_token(
1129
+ prompt,
1130
+ self.tokenizer,
1131
+ return_tensors="pt",
1132
+ )
1133
+ for prompt in text_list
1134
+ ]
1135
+ input_ids = [
1136
+ (
1137
+ torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids_i])
1138
+ if input_ids_i[0] != self.tokenizer.bos_token_id
1139
+ else input_ids_i
1140
+ )
1141
+ for input_ids_i in input_ids
1142
+ ]
1143
+
1144
+ targets = copy.deepcopy(input_ids)
1145
+ for i in range(len(targets)):
1146
+ targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX
1147
+
1148
+ return dict(input_ids=input_ids, labels=targets, image=image_list)
1149
+
1150
+
1151
+ class LazyVideoWebDataset(Dataset):
1152
+ """Dataset for supervised fine-tuning."""
1153
+
1154
+ def __init__(
1155
+ self,
1156
+ data_path: str,
1157
+ image_folder: str,
1158
+ tokenizer: transformers.PreTrainedTokenizer,
1159
+ data_args: DataArguments,
1160
+ training_args: TrainingArguments,
1161
+ # cache_path: str,
1162
+ # n_samples_per_idx=4,
1163
+ ):
1164
+ super().__init__()
1165
+
1166
+ # from llava.data.simple_video_dataset import SimpleVideoDataset
1167
+
1168
+ from llava.data.simple_vila_webdataset import VILAWebDataset
1169
+
1170
+ print("[DEBUG] ", osp.abspath(data_path))
1171
+ self.dataset = VILAWebDataset(
1172
+ data_path=osp.abspath(data_path),
1173
+ meta_path=f"{osp.abspath(data_path)}/wids-meta.json",
1174
+ # cache_dir=cache_path,
1175
+ )
1176
+
1177
+ # None: use original caption
1178
+ # Folder path: use the captions in that folder to override the original ones
1179
+ self.caption_choice = None
1180
+ self.data_path = data_path
1181
+
1182
+ if data_args.caption_choice is not None:
1183
+ self.caption_choice = data_args.caption_choice
1184
+ print("[recap] Override LazyVideo caption using ", self.caption_choice)
1185
+
1186
+ print("total samples", len(self.dataset))
1187
+ # InternVid: TODO
1188
+ PROCESS_GROUP_MANAGER = get_pg_manager()
1189
+ if PROCESS_GROUP_MANAGER is not None:
1190
+ import torch.distributed as dist
1191
+
1192
+ sequence_parallel_size = training_args.seq_parallel_size
1193
+ sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank
1194
+ else:
1195
+ sequence_parallel_size = 1
1196
+ print("sequence_parallel_size", sequence_parallel_size)
1197
+ rank = (
1198
+ training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2
1199
+ ) # int(os.environ["RANK"])
1200
+ world_size = (
1201
+ training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32
1202
+ ) # int(os.environ["WORLD_SIZE"])
1203
+ print(
1204
+ "rank",
1205
+ rank,
1206
+ "world_size",
1207
+ world_size,
1208
+ )
1209
+ self.rank = rank
1210
+ # rank = int(os.environ["RANK"]) if "RANK" in os.environ else 2
1211
+ # world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 32
1212
+
1213
+ self.tokenizer = tokenizer
1214
+ self.data_args = data_args
1215
+
1216
+ self.missing_uids = set()
1217
+
1218
+ def __len__(self):
1219
+ return len(self.dataset)
1220
+
1221
+ @property
1222
+ def modality_lengths(self):
1223
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
1224
+ length_list = []
1225
+ for samples in self.data_list:
1226
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
1227
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
1228
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
1229
+ length_list.append(cur_len)
1230
+ return length_list
1231
+
1232
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
1233
+ ADD_TEXT_PROMPT = False
1234
+ num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8
1235
+ loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0
1236
+
1237
+ info = self.dataset[i]
1238
+
1239
+ caption = ""
1240
+ # print(info)
1241
+ if ".mp4" in info:
1242
+ caption, video_path = info[".txt"], info[".mp4"]
1243
+ else:
1244
+ video_path = None
1245
+ caption = "Empty video."
1246
+
1247
+ images, frames_loaded, _ = LazySupervisedDataset._load_video(
1248
+ video_path, num_video_frames, loader_fps, self.data_args
1249
+ )
1250
+
1251
+ if frames_loaded == 0:
1252
+ caption = "Empty video."
1253
+
1254
+ if self.caption_choice is not None:
1255
+ shard = info["__shard__"]
1256
+ uuid = osp.join(info["__shard__"], info["__key__"])
1257
+ url = info["__key__"]
1258
+ tar_name = osp.basename(info["__shard__"])
1259
+
1260
+ try:
1261
+ shard_json_path = osp.join(self.caption_choice, tar_name.replace(".tar", ".json"))
1262
+ shard_json = lru_json_load(shard_json_path)
1263
+ caption = shard_json[url]["summary"]["output"]
1264
+ except (KeyError, FileNotFoundError, json.decoder.JSONDecodeError):
1265
+ if uuid not in self.missing_uids:
1266
+ print("override caption not found for ", uuid)
1267
+ self.missing_uids.add(uuid)
1268
+
1269
+ # print(f"[DEBUG {uuid}]", caption)
1270
+
1271
+ frames_loaded_successfully = len(images)
1272
+ if caption is None:
1273
+ caption = ""
1274
+ prompt = "<image>\n" * frames_loaded_successfully + caption
1275
+ image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images])
1276
+
1277
+ input_ids = tokenizer_image_token(
1278
+ prompt,
1279
+ self.tokenizer,
1280
+ return_tensors="pt",
1281
+ )
1282
+ targets = copy.deepcopy(input_ids)
1283
+ data_dict = dict(input_ids=input_ids, labels=targets, image=image_tensor)
1284
+
1285
+ return data_dict
1286
+
1287
+
1288
+ class DataCollatorForSupervisedDatasetSeqParallel:
1289
+ """Collate examples for supervised fine-tuning.
1290
+ This class is originally implemented by the LLaVA team and
1291
+ modified by Haotian Tang."""
1292
+
1293
+ def __init__(
1294
+ self,
1295
+ tokenizer: transformers.PreTrainedTokenizer,
1296
+ data_args: DataArguments,
1297
+ training_args: TrainingArguments,
1298
+ sp_degree: int,
1299
+ sp_rank: int,
1300
+ ring_degree: int,
1301
+ ring_type: str,
1302
+ ):
1303
+ self.tokenizer = tokenizer
1304
+ self.data_args = data_args
1305
+ self.training_args = training_args
1306
+ self.sp_degree = sp_degree
1307
+ self.sp_rank = sp_rank
1308
+ self.ring_degree = ring_degree
1309
+ self.ring_type = ring_type
1310
+
1311
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
1312
+ input_ids, labels, images = [], [], []
1313
+ image_token_id = self.tokenizer.media_token_ids["image"]
1314
+ video_token_id = self.tokenizer.media_token_ids["video"]
1315
+
1316
+ for instance in instances:
1317
+ if not isinstance(instance["input_ids"], list):
1318
+ input_ids.append(instance["input_ids"])
1319
+ else:
1320
+ input_ids += instance["input_ids"]
1321
+ if not isinstance(instance["labels"], list):
1322
+ labels.append(instance["labels"])
1323
+ else:
1324
+ labels += instance["labels"]
1325
+ # Note (kentang-mit@: we do not directly push tensors to
1326
+ # images, but list of tensors.
1327
+ if "video" in instance:
1328
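+ # expand the single <video> placeholder into one <image> token (plus "\n") per frame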
+ instance["image"] = torch.cat(instance["video"])
1329
+ video_id_pos = torch.where(input_ids[-1] == video_token_id)[0][0]
1330
+ replace_ids = torch.Tensor(
1331
+ ([image_token_id] + self.tokenizer.encode("\n")) * instance["image"].shape[0],
1332
+ device=input_ids[-1].device,
1333
+ )
1334
+ input_ids[-1] = torch.cat(
1335
+ [input_ids[-1][:video_id_pos], replace_ids, input_ids[-1][video_id_pos + 1 :]]
1336
+ ).to(input_ids[-1].dtype)
1337
+ labels[-1] = torch.cat(
1338
+ [
1339
+ labels[-1][:video_id_pos],
1340
+ torch.Tensor([IGNORE_INDEX] * instance["image"].shape[0] * 2),
1341
+ labels[-1][video_id_pos + 1 :],
1342
+ ]
1343
+ ).to(labels[-1].dtype)
1344
+ instance.pop("video")
1345
+
1346
+ if "image" in instance:
1347
+ cur_image = instance["image"]
1348
+ assert len(cur_image.shape) == 4
1349
+ # n_images, 3, size, size
1350
+ if cur_image.shape[0] == 0:
1351
+ warnings.warn("loaded one sample without images.")
1352
+ if not isinstance(instance["input_ids"], list):
1353
+ # datasets other than coyo, not packing >1 samples together
1354
+ images.append(cur_image)
1355
+ else:
1356
+ # coyo-like datasets
1357
+ images.extend(cur_image.chunk(cur_image.size(0), 0))
1358
+ else:
1359
+ warnings.warn("loaded one sample without images.")
1360
+ images.append([])
1361
+ # kentang-mit@: we need to make sure these two lists have
1362
+ # the same length. We will use input_ids to filter out images corresponding
1363
+ # to truncated <image> tokens later.
1364
+
1365
+ max_num_images = max([len(_images) for _images in images])
1366
+ for _images, _input_ids in zip(images, input_ids):
1367
+ assert (
1368
+ len(_images) == (_input_ids == image_token_id).sum().item()
1369
+ ), f"Number mismatch between images and placeholder image tokens in 'len(_images) == (_input_ids == image_token_id).sum().item()'.\
1370
+ Expect to have {len(_images)} images but only found {(_input_ids == image_token_id).sum().item()} images in tokens. \
1371
+ Error input_ids: {_input_ids} {self.tokenizer.decode([x if x != -200 else 200 for x in _input_ids])}"
1372
+
1373
+ NUM_TOKENS_PER_IMAGE = self.data_args.num_image_tokens
1374
+ if hasattr(self.data_args.image_processor, "crop_size"):
1375
+ crop_size = self.data_args.image_processor.crop_size
1376
+ else:
1377
+ crop_size = self.data_args.image_processor.size
1378
+
1379
+ # Init the padding sample
1380
+ seq_id = 0
1381
+ while seq_id < len(input_ids):
1382
+ # Skip the samples without images
1383
+ dummy_image = torch.ones((1, 3, crop_size["height"], crop_size["width"]), device=input_ids[seq_id].device)
1384
+ # dummy input_ids include one bos, one image token, and one eos
1385
+ dummy_input_ids = torch.zeros_like(input_ids[seq_id][:3])
1386
+ dummy_input_ids[0] = self.tokenizer.bos_token_id
1387
+ dummy_input_ids[1] = image_token_id
1388
+ dummy_input_ids[2] = self.tokenizer.eos_token_id
1389
+ dummy_labels = copy.deepcopy(dummy_input_ids)
1390
+ dummy_labels[:2] = IGNORE_INDEX
1391
+ dummy_seqlen = NUM_TOKENS_PER_IMAGE + 2  # +2 for the bos and eos tokens around the expanded image token
1392
+ dummy_position_ids = torch.arange(start=0, end=dummy_seqlen, dtype=torch.int32)
1393
+ break
1394
+
1395
+ # Sort with the real length of the sequence
1396
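+ # each <image> placeholder later expands to NUM_TOKENS_PER_IMAGE tokens, so the key below approximates the packed length after expansion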
+ combined = sorted(
1397
+ zip(input_ids, labels, images),
1398
+ key=lambda x: len(x[2]) * (NUM_TOKENS_PER_IMAGE - 1) + x[0].size(-1),
1399
+ reverse=True, # Start Packing from the sequence with most images.
1400
+ )
1401
+ sorted_ids, sorted_labels, sorted_images = zip(*combined)
1402
+ sorted_ids, sorted_labels, sorted_images = list(sorted_ids), list(sorted_labels), list(sorted_images)
1403
+ max_seq_length = self.tokenizer.model_max_length # len(sorted_ids[0])
1404
+ max_sample_len = 0
1405
+
1406
+ batches = []
1407
+ label_batches = []
1408
+ position_ids = []
1409
+ batch_images = []
1410
+ seqlens_in_batch = []
1411
+
1412
+ i = 0
1413
+ while i < len(sorted_ids):
1414
+ current_batch = torch.tensor([], dtype=torch.int32)
1415
+ current_label_batch = torch.tensor([], dtype=torch.int32)
1416
+ current_position_ids = torch.tensor([], dtype=torch.int32)
1417
+ current_batch_images = []
1418
+ current_num_images = 0
1419
+ current_len = 0
1420
+ current_num_samples = 0
1421
+
1422
+ # Pack a few samples into one sample
1423
+ while i < len(sorted_ids):
1424
+ num_images = (sorted_ids[i] == image_token_id).sum().item()
1425
+ num_image_tokens_added = num_images * (NUM_TOKENS_PER_IMAGE - 1)
1426
+ num_incoming_tokens = sorted_ids[i].size(-1) + num_image_tokens_added
1427
+
1428
+ # Handle RingAttn_Varlen which requires `seqlens_in_batch` should be divisible by `ring_degree`
1429
+ if self.ring_degree > 1:
1430
+ RING_PAD_TOKEN_INDEX = 2
1431
+ if self.ring_type == "ring_varlen":
1432
+ if num_incoming_tokens % self.sp_degree != 0:
1433
+ pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree
1434
+ num_incoming_tokens += pad_len
1435
+ # pad `input_ids`
1436
+ pad_tensor = torch.full(
1437
+ (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
1438
+ )
1439
+ sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
1440
+
1441
+ # pad `label`
1442
+ pad_label_tensor = torch.full(
1443
+ (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
1444
+ )
1445
+ sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
1446
+ elif self.ring_type == "zigzag_ring_varlen":
1447
+ self.zigzag_sp_degree = self.sp_degree * 2
1448
+ if num_incoming_tokens % self.zigzag_sp_degree != 0:
1449
+ pad_len = self.zigzag_sp_degree - num_incoming_tokens % self.zigzag_sp_degree
1450
+ num_incoming_tokens += pad_len
1451
+ # pad `input_ids`
1452
+ pad_tensor = torch.full(
1453
+ (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
1454
+ )
1455
+ sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
1456
+
1457
+ # pad `label`
1458
+ pad_label_tensor = torch.full(
1459
+ (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
1460
+ )
1461
+ sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
1462
+ else:
1463
+ raise ValueError(f"Invalid ring_type: {self.ring_type}")
1464
+
1465
+ if num_incoming_tokens > max_seq_length:
1466
+ print(
1467
+ f"Warning: Skipping one packed sample with {num_incoming_tokens} tokens,\
1468
+ please consider increasing max seq len {max_seq_length}."
1469
+ )
1470
+ i += 1
1471
+ continue
1472
+
1473
+ if (
1474
+ (current_num_images == 0)
1475
+ or (current_num_images < self.sp_degree)
1476
+ or (
1477
+ (current_num_images + num_images <= max_num_images)
1478
+ and (current_len + num_incoming_tokens <= max_sample_len)
1479
+ )
1480
+ ) and (current_len + num_incoming_tokens <= max_seq_length):
1481
+ current_num_images += num_images
1482
+ current_len += num_incoming_tokens
1483
+ current_num_samples += 1
1484
+ current_position_ids = torch.cat(
1485
+ (current_position_ids, torch.arange(start=0, end=num_incoming_tokens)), dim=0
1486
+ )
1487
+ current_batch = torch.cat((current_batch, sorted_ids[i]), dim=0)
1488
+ sorted_labels[i][0] = IGNORE_INDEX
1489
+ current_label_batch = torch.cat((current_label_batch, sorted_labels[i]), dim=0)
1490
+ seqlens_in_batch.append(num_incoming_tokens)
1491
+ current_batch_images.extend(sorted_images[i])
1492
+ i += 1
1493
+ assert current_num_images == len(current_batch_images)
1494
+ else:
1495
+ break
1496
+
1497
+ # Pad the sample with the dummy image sample if there are not enough images
1498
+ MAX_RETRY = self.sp_degree
1499
+ num_retry = 0
1500
+ while current_num_images < self.sp_degree and current_len < max_seq_length and num_retry <= MAX_RETRY:
1501
+ current_num_images += dummy_image.size(0)
1502
+ current_len += dummy_seqlen
1503
+ current_num_samples += 1
1504
+ current_position_ids = torch.cat((current_position_ids, dummy_position_ids), dim=0)
1505
+ current_batch = torch.cat((current_batch, dummy_input_ids), dim=0)
1506
+ current_label_batch = torch.cat((current_label_batch, dummy_labels), dim=0)
1507
+ seqlens_in_batch.append(dummy_seqlen)
1508
+ current_batch_images.extend(dummy_image)
1509
+ # We pad from left side to ensure correct grad flow
1510
+ # current_batch = torch.cat((dummy_input_ids, current_batch), dim=0)
1511
+ # current_label_batch = torch.cat((dummy_labels, current_label_batch), dim=0)
1512
+ # seqlens_in_batch.insert(0, dummy_seqlen)
1513
+ # current_batch_images = torch.cat((dummy_image, current_batch_images), dim=0)
1514
+ num_retry += 1
1515
+
1516
+ # Drop the samples that do not have enough images
1517
+ if current_num_images < self.sp_degree:
1518
+ print(f"Warning: Skipping one packed sample with {current_num_images} images")
1519
+ seqlens_in_batch = seqlens_in_batch[:-current_num_samples]
1520
+ continue
1521
+
1522
+ max_sample_len = max(max_sample_len, current_len)
1523
+ batches.append(current_batch)
1524
+ label_batches.append(current_label_batch)
1525
+ position_ids.append(current_position_ids)
1526
+ batch_images.append(current_batch_images)
1527
+
1528
+ try:
1529
+ assert current_num_images == len(torch.where(current_batch == image_token_id)[0].tolist())
1530
+ except AssertionError:
1531
+ print(f"Error num_images on {self.sp_rank}", current_num_images)
1532
+ print("current_batch", current_batch)
1533
+ print(
1534
+ f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:",
1535
+ len(torch.where(current_batch == image_token_id)[0].tolist()),
1536
+ )
1537
+ print(f"Error len(current_batch_images) on {self.sp_rank}:", len(current_batch_images))
1538
+ raise AssertionError
1539
+
1540
+ # Split for sequence parallelism
1541
+ for i in range(len(batches)):
1542
+ image_token_indices = torch.where(batches[i] == image_token_id)[0].tolist()
1543
+ image_ids = torch.arange(0, len(image_token_indices), dtype=torch.int32)
1544
+ batches[i] = extract_local_input_ids(
1545
+ batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id
1546
+ )
1547
+ label_batches[i] = extract_local_input_ids(
1548
+ label_batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id
1549
+ )
1550
+ batch_images[i] = torch.concat(
1551
+ extract_local_from_list(batch_images[i], self.sp_rank, self.sp_degree), dim=0
1552
+ )
1553
+ H, W = batch_images[i].size(-2), batch_images[i].size(-1)
1554
+ batch_images[i] = batch_images[i].reshape(-1, 3, H, W)
1555
+ num_images = len(batch_images[i])
1556
+
1557
+ try:
1558
+ assert num_images == len(torch.where(batches[i] == image_token_id)[0].tolist())
1559
+ except AssertionError:
1560
+ print(f"Error num_images on {self.sp_rank}", num_images)
1561
+ print("batches[i]", batches[i])
1562
+ print(
1563
+ f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:",
1564
+ len(torch.where(batches[i] == image_token_id)[0].tolist()),
1565
+ )
1566
+ print(f"Error batch_images[i] on {self.sp_rank}:", batch_images[i].shape)
1567
+ raise AssertionError
1568
+ position_ids[i] = extract_local_position_ids(
1569
+ position_ids[i], image_token_indices, image_ids, self.sp_rank, self.sp_degree, NUM_TOKENS_PER_IMAGE - 1
1570
+ )
1571
+
1572
+ input_ids = torch.nn.utils.rnn.pad_sequence(
1573
+ batches, batch_first=True, padding_value=self.tokenizer.pad_token_id
1574
+ )
1575
+ labels = torch.nn.utils.rnn.pad_sequence(label_batches, batch_first=True, padding_value=IGNORE_INDEX)
1576
+ seqlens_in_batch = [torch.tensor(x) for x in seqlens_in_batch]
1577
+ seqlens_in_batch = torch.stack(seqlens_in_batch, axis=0)
1578
+ seqlens_in_batch = seqlens_in_batch.flatten()
1579
+ position_ids = torch.nn.utils.rnn.pad_sequence(position_ids, batch_first=True, padding_value=-1)
1580
+
1581
+ if batch_images:
1582
+ batch_images = [torch.unbind(images) for images in batch_images]
1583
+ flat_batch_images = [item for sublist in batch_images for item in sublist]
1584
+ else:
1585
+ flat_batch_images = None
1586
+ batch = dict(
1587
+ input_ids=input_ids,
1588
+ labels=labels,
1589
+ # notice that we inject attention mask here
1590
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
1591
+ seqlens_in_batch=seqlens_in_batch,
1592
+ media={"image": flat_batch_images},
1593
+ media_config={"image": {}},
1594
+ position_ids=position_ids,
1595
+ )
1596
+ return batch
1597
+
1598
+
1599
+ def make_supervised_data_module(
1600
+ tokenizer: PreTrainedTokenizer,
1601
+ data_args: DataArguments,
1602
+ training_args: TrainingArguments,
1603
+ ) -> Dict:
1604
+ """Make dataset and collator for supervised fine-tuning.
1605
+ This function is originally implemented by the LLaVA team and
1606
+ modified by Jason Lu, Haotian Tang and Ligeng Zhu."""
1607
+ datasets_mixture.register_datasets_mixtures()
1608
+
1609
+ from .builder import build_dataset
1610
+
1611
+ train_dataset = build_dataset(data_args.data_mixture, data_args, training_args, tokenizer)
1612
+ training_args.sample_lens = [len(d) for d in train_dataset.datasets]
1613
+
1614
+ PROCESS_GROUP_MANAGER = get_pg_manager()
1615
+ if PROCESS_GROUP_MANAGER is None:
1616
+ data_collator = DataCollator(tokenizer=tokenizer)
1617
+ else:
1618
+ sp_degree = training_args.seq_parallel_size
1619
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
1620
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
1621
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
1622
+ data_collator = DataCollatorForSupervisedDatasetSeqParallel(
1623
+ tokenizer=tokenizer,
1624
+ data_args=data_args,
1625
+ training_args=training_args,
1626
+ sp_degree=sp_degree,
1627
+ sp_rank=sp_rank,
1628
+ ring_degree=ring_degree,
1629
+ ring_type=ring_type,
1630
+ )
1631
+
1632
+ return dict(
1633
+ train_dataset=train_dataset,
1634
+ data_collator=data_collator,
1635
+ )
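For reference, a minimal sketch of how this data module is typically wired into a training script; the argument parsing and tokenizer checkpoint below are illustrative assumptions, not code from this commit:

    from transformers import AutoTokenizer, HfArgumentParser

    # DataArguments / TrainingArguments are the dataclasses referenced throughout dataset.py
    parser = HfArgumentParser((DataArguments, TrainingArguments))
    data_args, training_args = parser.parse_args_into_dataclasses()

    # hypothetical base-LLM checkpoint; the real pipeline also registers media tokens on the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("/path/to/base-llm")

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, training_args=training_args)
    # data_module["train_dataset"] and data_module["data_collator"] are then handed to the trainer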
llava/data/datasets_mixture.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import warnings
24
+ from dataclasses import dataclass, field
25
+
26
+
27
+ @dataclass
28
+ class Dataset:
29
+ dataset_name: str
30
+ dataset_type: str = field(default="torch")
31
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
32
+ meta_path: str = field(default=None, metadata={"help": "Path to the meta data for webdataset."})
33
+ image_path: str = field(default=None, metadata={"help": "Path to the training image data."})
34
+ speech_path: str = field(default=None, metadata={"help": "Path to the training speech data."})
35
+ caption_choice: str = field(default=None, metadata={"help": "Path to the caption directory for recaption."})
36
+ description: str = field(
37
+ default=None,
38
+ metadata={
39
+ "help": "Detailed desciption of where the data is from, how it is labelled, intended use case and the size of the dataset."
40
+ },
41
+ )
42
+ test_script: str = None
43
+ maintainer: str = None
44
+ ############## ############## ############## ############## ############## ##############
45
+ caption_choice: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
46
+ caption_choice_2: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
47
+ start_idx: float = field(default=-1, metadata={"help": "Start index of the dataset."})
48
+ end_idx: float = field(default=-1, metadata={"help": "End index of the dataset."})
49
+
50
+
51
+ DATASETS_LEGACY = {}
52
+
53
+
54
+ def add_dataset(dataset):
55
+ if dataset.dataset_name in DATASETS_LEGACY:
56
+ # make sure the dataset_name is unique
57
+ warnings.warn(f"{dataset.dataset_name} already existed in DATASETS. Make sure the name is unique.")
58
+ assert "+" not in dataset.dataset_name, "Dataset name cannot include symbol '+'."
59
+ DATASETS_LEGACY.update({dataset.dataset_name: dataset})
60
+
61
+
62
+ def register_datasets_mixtures():
63
+ ############## ############## ############## ############## ############## ##############
64
+ # Audio Datasets
65
+ ############## ############## ############## ############## ############## ##############
66
+
67
+ data_mixture_1 = Dataset(
68
+ dataset_name="data_mixture_1",
69
+ dataset_type="torch",
70
+ data_path="/path/to/your/data_mixture_1/train.json",
71
+ )
72
+ add_dataset(data_mixture_1)
73
+
74
+ data_mixture_2 = Dataset(
75
+ dataset_name="data_mixture_2",
76
+ dataset_type="torch",
77
+ data_path="/path/to/your/data_mixture_2/train.json",
78
+ )
79
+ add_dataset(data_mixture_2)
80
+ # Add more data mixtures below
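Registering an additional mixture follows the same pattern as data_mixture_1 and data_mixture_2 above; a hypothetical example (name and paths are placeholders):

    my_audio_qa = Dataset(
        dataset_name="my_audio_qa",  # must be unique and must not contain '+'
        dataset_type="torch",
        data_path="/path/to/my_audio_qa/train.json",
        description="Hypothetical in-house audio QA set, shown only to illustrate the registration pattern.",
    )
    add_dataset(my_audio_qa)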
llava/data/registry/datasets/audio_test.yaml ADDED
@@ -0,0 +1,97 @@
1
+ ---
2
+ Clotho-AQA-AQA:
3
+ _target_: llava.data.LLaVADataset
4
+ data_path: Clotho-AQA-AQA/test.json
5
+ Music-AVQA-AQA_All:
6
+ _target_: llava.data.LLaVADataset
7
+ data_path: Music-AVQA-AQA_All/test.json
8
+ CochlScene-SceneClassification:
9
+ _target_: llava.data.LLaVADataset
10
+ data_path: CochlScene-SceneClassification/test.json
11
+ NSynth-Source:
12
+ _target_: llava.data.LLaVADataset
13
+ data_path: NSynth-Source/test.json
14
+ NSynth-Instrument:
15
+ _target_: llava.data.LLaVADataset
16
+ data_path: NSynth-Instrument/test.json
17
+ FSD50k-EventClassification:
18
+ _target_: llava.data.LLaVADataset
19
+ data_path: FSD50k-EventClassification/test.json
20
+ Clotho-v2-AudioCaptioning:
21
+ _target_: llava.data.LLaVADataset
22
+ data_path: Clotho-v2-AudioCaptioning/test.json
23
+ audiocaps-AudioCaptioning:
24
+ _target_: llava.data.LLaVADataset
25
+ data_path: audiocaps-AudioCaptioning/test.json
26
+ ravdess-EmotionClassification:
27
+ _target_: llava.data.LLaVADataset
28
+ data_path: ravdess-EmotionClassification/val.json
29
+ GTZAN-GenreClassification:
30
+ _target_: llava.data.LLaVADataset
31
+ data_path: GTZAN-GenreClassification/test.json
32
+ UrbanSound8K-EventClassification:
33
+ _target_: llava.data.LLaVADataset
34
+ data_path: UrbanSound8K-EventClassification/train.json
35
+ Medley-solos-DB-InstrClassification:
36
+ _target_: llava.data.LLaVADataset
37
+ data_path: Medley-solos-DB-InstrClassification/test.json
38
+ ESC50-EventClassification:
39
+ _target_: llava.data.LLaVADataset
40
+ data_path: ESC50-EventClassification/train.json
41
+ CREMA-D-EmotionClassification:
42
+ _target_: llava.data.LLaVADataset
43
+ data_path: CREMA-D-EmotionClassification/test.json
44
+ IEMOCAP-EmotionClassification:
45
+ _target_: llava.data.LLaVADataset
46
+ data_path: IEMOCAP-EmotionClassification/test.json
47
+ MELD-EmotionClassification:
48
+ _target_: llava.data.LLaVADataset
49
+ data_path: MELD-EmotionClassification/test.json
50
+ MELD-SentimentClassification:
51
+ _target_: llava.data.LLaVADataset
52
+ data_path: MELD-SentimentClassification/test.json
53
+ MMAU:
54
+ _target_: llava.data.LLaVADataset
55
+ data_path: MMAU/test.json
56
+ MMAU-mini:
57
+ _target_: llava.data.LLaVADataset
58
+ data_path: MMAU/test-mini.json
59
+ AudioEntailmentQA:
60
+ _target_: llava.data.LLaVADataset
61
+ data_path: AudioEntailmentQA/test.json
62
+ SPGI-ASR:
63
+ _target_: llava.data.LLaVADataset
64
+ data_path: SPGI-ASR/val.json
65
+ SWBD-ASR:
66
+ _target_: llava.data.LLaVADataset
67
+ data_path: SWBD-ASR/val.json
68
+ LibriSpeech-ASR-clean:
69
+ _target_: llava.data.LLaVADataset
70
+ data_path: LibriSpeech-ASR/test_clean.json
71
+ LibriSpeech-ASR-other:
72
+ _target_: llava.data.LLaVADataset
73
+ data_path: LibriSpeech-ASR/test_other.json
74
+ VoxPopuli-ASR:
75
+ _target_: llava.data.LLaVADataset
76
+ data_path: VoxPopuli-ASR/test.json
77
+ Europarl-ASR:
78
+ _target_: llava.data.LLaVADataset
79
+ data_path: Europarl-ASR/test.json
80
+ CV-ASR:
81
+ _target_: llava.data.LLaVADataset
82
+ data_path: CV-ASR/test.json
83
+ GigaSpeech-ASR:
84
+ _target_: llava.data.LLaVADataset
85
+ data_path: GigaSpeech-ASR/test.json
86
+ CompA-R-AQA:
87
+ _target_: llava.data.LLaVADataset
88
+ data_path: CompA-R-AQA/test.json
89
+ MuschoMusicQA:
90
+ _target_: llava.data.LLaVADataset
91
+ data_path: MuschoMusicQA/test.json
92
+ CMM:
93
+ _target_: llava.data.LLaVADataset
94
+ data_path: CMM/test.json
95
+ AIR-Bench:
96
+ _target_: llava.data.LLaVADataset
97
+ data_path: AIR-Bench/test.json
llava/data/registry/datasets/default.yaml ADDED
@@ -0,0 +1,5 @@
1
+ ---
2
+ dummy:
3
+ _target_: llava.data.DummyDataset
4
+ num_instances: 10000
5
+ comments: dummy dataset for testing
llava/data/registry/mixtures.yaml ADDED
@@ -0,0 +1,78 @@
1
+ ---
2
+ audio_speech_all:
3
+ -CV-ASR_1
4
+ -MELD-EmotionClassification+
5
+ -BBCSoundEffects-AudioDescription
6
+ -SWBD-ASR_1
7
+ -WavCaps-SoundBible-AudioCaptioning
8
+ -AudioSet-Speech-Audio-QA
9
+ -SONYC-UST-EventClassification
10
+ -VoxPopuli-ASR_1
11
+ -FSD50k-EventClassification
12
+ -SalmonnQA
13
+ -emov-db-EmotionClassification
14
+ -LLARK_MagnaTagATune-mir+tess-EmotionClassification
15
+ -Europarl-ASR_1
16
+ -jl-corpus-EmotionClassification
17
+ -Ego-10-AudioCaptioning
18
+ -SPGI-ASR_1
19
+ -CREMA-D-EmotionClassification
20
+ -MusicBenchQA
21
+ -WavCaps-BBC_Sound_Effects-AudioCaptioning
22
+ -NSynth-Instrument
23
+ -SpokenSquadQA
24
+ -NSynth-MIR
25
+ -AudioEntailmentQA
26
+ -GigaSpeech-ASR_1
27
+ -WavCaps-AudioSet_SL-AudioCaptioning
28
+ -NonSpeech7k-EventClassification
29
+ -chime-home-EventClassification
30
+ -MusicCaps-AudioCaptioning
31
+ -LP-MusicCaps-MSD-AudioCaptioning
32
+ -Ego-30-AudioCaptioning
33
+ -NSynth-Source+Clotho-v2-AudioCaptioning
34
+ -LP-MusicCaps-MC-AudioCaptioning
35
+ -Clotho-AQA-EventClassification
36
+ -WavCaps-FreeSound-AudioCaptioning
37
+ -LLARK_MagnaTagATune-reasoning
38
+ -AudioSet-Temporal-Speech-Audio-QA
39
+ -TUT-EventClassification
40
+ -ESC50-EventClassification
41
+ -WavText5K-Tagging
42
+ -MELD-SentimentClassification
43
+ -Music-AVQA-AQA_All
44
+ -Music-AVQA-AVQA_All
45
+ -MACS-AudioCaptioning
46
+ -Medley-solos-DB-InstrClassification
47
+ -AudioSet-EventClassification
48
+ -OMGEmotion-EmotionClassification
49
+ -FMA-GenreClassification
50
+ -Epidemic_sound-AudioCaptioning
51
+ -CochlScene-SceneClassification
52
+ -LLARK_FMA-reasoning
53
+ -ravdess-EmotionClassification
54
+ -CompA-R-AQA
55
+ -MU-LLAMA-AQA
56
+ -musdbhq-InstrClassification
57
+ -UrbanSound8K-EventClassification
58
+ -audiocaps-AudioCaptioning
59
+ -VocalSound-VocalClassification
60
+ -CLAP_freesound-AudioCaptioning
61
+ -MMAUQA
62
+ -SongDescriber-AudioCaptioning
63
+ -HeySQuADQA
64
+ -Mira-AudioCaptioning
65
+ -Clotho-AQA-AQA
66
+ -LibriSpeech-ASR_1
67
+ -IEMOCAP-EmotionClassification
68
+ -AudioSetFullwoAudioMusicCaps-EventClassification
69
+ -MSP-PODCAST-Publish-1.9-EmotionClassification
70
+ -OpenAQA-AQA
71
+ -SoundDescs-AudioDescription
72
+ -LibriSQA
73
+ -LLARK_FMA-mir
74
+ -LP-MusicCaps-MTT-AudioCaptioning
75
+ -GTZAN-GenreClassification
76
+ -musdbhq-captioning
77
+ -YesNoQA
78
+
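The mixture name defined above (audio_speech_all) is what a training run passes as its data mixture; build_dataset in llava/data/builder.py (also part of this upload) resolves it to the registered datasets. Since add_dataset forbids '+' inside a single dataset name, '+' is available as a separator for ad-hoc combinations. A hedged sketch of how the name might be resolved in code:

    # assumes data_args, training_args and tokenizer already exist
    train_dataset = build_dataset("audio_speech_all", data_args, training_args, tokenizer)
    # or combine two registered datasets on the fly:
    train_dataset = build_dataset("data_mixture_1+data_mixture_2", data_args, training_args, tokenizer)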
llava/entry.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import os
24
+ import typing
25
+ from typing import List, Optional
26
+
27
+ if typing.TYPE_CHECKING:
28
+ from transformers import PreTrainedModel
29
+ else:
30
+ PreTrainedModel = None
31
+
32
+ __all__ = ["load"]
33
+
34
+
35
+ def load(
36
+ model_path: str,
37
+ model_base: Optional[str] = None,
38
+ devices: Optional[List[int]] = None,
39
+ **kwargs,
40
+ ) -> PreTrainedModel:
41
+ import torch
42
+
43
+ from llava.conversation import auto_set_conversation_mode
44
+ from llava.mm_utils import get_model_name_from_path
45
+ from llava.model.builder import load_pretrained_model
46
+
47
+ auto_set_conversation_mode(model_path)
48
+
49
+ model_name = get_model_name_from_path(model_path)
50
+ model_path = os.path.expanduser(model_path)
51
+ if os.path.exists(os.path.join(model_path, "model")):
52
+ model_path = os.path.join(model_path, "model")
53
+
54
+ # Set `max_memory` to constrain which GPUs to use
55
+ if devices is not None:
56
+ assert "max_memory" not in kwargs, "`max_memory` should not be set when `devices` is set"
57
+ kwargs.update(max_memory={device: torch.cuda.get_device_properties(device).total_memory for device in devices})
58
+
59
+ model = load_pretrained_model(model_path, model_name, model_base, **kwargs)[1]
60
+ return model
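A minimal usage sketch of this entry point, mirroring how eval_audio_bench.py later in this upload calls it (paths are placeholders):

    import llava

    model = llava.load("/path/to/audio-flamingo-checkpoint", devices=[0])
    sound = llava.Sound("/path/to/example.wav")
    response = model.generate_content([sound, "Describe the audio."],
                                      generation_config=model.default_generation_config)
    print(response)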
llava/eval/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import os
8
+
9
+ from llava.utils import io
10
+
11
+ __all__ = ["EVAL_ROOT", "TASKS"]
12
+
13
+
14
+ EVAL_ROOT = "scripts/eval"
15
+ TASKS = io.load(os.path.join(os.path.dirname(__file__), "registry_audio.yaml"))
llava/eval/eval_audio_bench.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import argparse
8
+ import csv
9
+ import itertools
10
+ import json
11
+ import os
12
+
13
+ import torch
14
+ from datasets import load_dataset
15
+ from tqdm import tqdm
16
+
17
+ import llava
18
+ from llava import conversation as conversation_lib
19
+ from llava.data.builder import DATASETS
20
+ from llava.eval.mmmu_utils.eval_utils import parse_choice
21
+ from llava.utils import distributed as dist
22
+ from llava.utils import io
23
+ from llava.utils.logging import logger
24
+
25
+
26
+ def load_existing_ids(output_file):
27
+ if not os.path.exists(output_file):
28
+ return set(), []
29
+ try:
30
+ with open(output_file, "r") as f:
31
+ lines = f.readlines()
32
+ outputs = [json.loads(line) for line in lines]
33
+ processed_ids = {item["id"] for item in outputs}
34
+ return processed_ids, outputs
35
+ except Exception as e:
36
+ print(f"Error loading existing outputs: {e}")
37
+ return set(), []
38
+
39
+
40
+ def main() -> None:
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--model-path", type=str, default=None)
43
+ parser.add_argument("--model-base", type=str, default=None)
44
+ parser.add_argument("--task", type=str, default=None)
45
+ parser.add_argument("--conv-mode", type=str, default="auto")
46
+ parser.add_argument("--generation-config", type=json.loads)
47
+ parser.add_argument("--output-dir", type=str, default=None)
48
+ args = parser.parse_args()
49
+
50
+ # Set up distributed environment
51
+ dist.init()
52
+ devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size())
53
+ torch.cuda.set_device(devices[0])
54
+
55
+ # Load the stage 3 model (line 56)
56
+ model = llava.load(args.model_base, model_base=None, devices=devices)
57
+ # Uncomment lines 58-63 (requires: from peft import PeftModel) to load the stage 3.5 model on top of stage 3 for thinking mode and long-audio mode
58
+ # model = PeftModel.from_pretrained(
59
+ # model,
60
+ # args.model_path,
61
+ # device_map="auto",
62
+ # torch_dtype=torch.float16,
63
+ # )
64
+ # Set up generation config
65
+ generation_config = model.default_generation_config
66
+ if args.generation_config is not None:
67
+ generation_config.update(**args.generation_config)
68
+
69
+ # Load data and chunk it
70
+ json_file = DATASETS[args.task]["data_path"]
71
+ instances = io.load(json_file)
72
+ instances = instances[dist.rank() :: dist.size()]
73
+
74
+ output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl")
75
+ processed_ids, outputs = load_existing_ids(output_path)
76
+
77
+ count = len(outputs)
78
+ # Run inference
79
+ new_outputs = []
80
+ for instance in tqdm(instances, disable=not dist.is_main()):
81
+ uuid = instance["id"]
82
+ sound_path = instance["sound"]
83
+
84
+ if sound_path in processed_ids:
85
+ continue # Skip if already processed
86
+ sound = llava.Sound(sound_path)
87
+ conversations = instance["conversations"]
88
+ question = conversations[0]["value"]
89
+
90
+ response = model.generate_content([sound, question], generation_config=generation_config)
91
+
92
+ print("response", response)
93
+
94
+ output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response}
95
+ new_outputs.append(output)
96
+ count += 1
97
+ if count % 20 == 0:
98
+ # Gather and save outputs
99
+ if dist.size() > 1:
100
+ outputs_new = dist.gather(new_outputs, dst=0)
101
+ if dist.is_main():
102
+ outputs_new = list(itertools.chain(*outputs_new))
103
+ final_outputs = outputs + outputs_new
104
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
105
+ else:
106
+ final_outputs = outputs + new_outputs
107
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
108
+ if dist.size() > 1:
109
+ new_outputs = dist.gather(new_outputs, dst=0)
110
+ if not dist.is_main():
111
+ return
112
+ new_outputs = list(itertools.chain(*new_outputs))
113
+ final_outputs = outputs + new_outputs
114
+ io.save(os.path.join(args.output_dir, "outputs_"+str(args.task)+".jsonl"), final_outputs)
115
+
116
+ if __name__ == "__main__":
117
+ main()
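The chunking above (instances[dist.rank() :: dist.size()]) shards the test set across ranks by striding; a tiny standalone illustration:

    instances = list(range(10))
    rank, world_size = 1, 4
    print(instances[rank::world_size])  # [1, 5, 9] -- rank 1's share of 10 items across 4 ranks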
llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc ADDED
Binary file (2.58 kB). View file
 
llava/eval/mmmu_utils/eval_utils.py ADDED
@@ -0,0 +1,61 @@
1
+ # This file is originated from the official MMMU codebase:
2
+ # https://github.com/MMMU-Benchmark/MMMU
3
+
4
+ import random
5
+
6
+ import numpy as np
7
+
8
+
9
+ def parse_choice(response, all_choices, index2ans=None):
10
+ """
11
+ Parse the prediction from the generated response.
12
+ Return the predicted index e.g., A, B, C, D.
13
+ """
14
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
15
+ response = response.strip(char)
16
+ response = " " + response + " " # add space to avoid partial match
17
+
18
+ index_ans = True
19
+ ans_with_brack = False
20
+ candidates = []
21
+ for choice in all_choices: # e.g., (A) (B) (C) (D)
22
+ if f"({choice})" in response:
23
+ candidates.append(choice)
24
+ ans_with_brack = True
25
+
26
+ if len(candidates) == 0:
27
+ for choice in all_choices: # e.g., A B C D
28
+ if f" {choice} " in response:
29
+ candidates.append(choice)
30
+
31
+ # if none of the above yields candidates, check whether the response is longer than 5 tokens and try to match the answer content
32
+ if len(candidates) == 0 and len(response.split()) > 5 and index2ans is not None:
33
+ for index, ans in index2ans.items():
34
+ if ans.lower() in response.lower():
35
+ candidates.append(index)
36
+ index_ans = False # it's content ans.
37
+
38
+ if len(candidates) == 0: # still not get answer, randomly choose one.
39
+ pred_index = random.choice(all_choices)
40
+ elif len(candidates) > 1:
41
+ start_indexes = []
42
+ if index_ans:
43
+ if ans_with_brack:
44
+ for can in candidates:
45
+ index = response.rfind(f"({can})")
46
+ start_indexes.append(index) # -1 will be ignored anyway
47
+ # start_indexes = [generated_response.index(f'({can})') for can in candidates]
48
+ else:
49
+ for can in candidates:
50
+ index = response.rfind(f" {can} ")
51
+ start_indexes.append(index)
52
+ else:
53
+ for can in candidates:
54
+ index = response.lower().rfind(index2ans[can].lower())
55
+ start_indexes.append(index)
56
+ # get the last one
57
+ pred_index = candidates[np.argmax(start_indexes)]
58
+ else: # if only one candidate, use it.
59
+ pred_index = candidates[0]
60
+
61
+ return pred_index
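A quick illustration of parse_choice on made-up inputs:

    all_choices = ["A", "B", "C", "D"]
    index2ans = {"A": "piano", "B": "violin", "C": "drums", "D": "flute"}
    parse_choice("The answer is (B).", all_choices, index2ans)              # -> "B" via the "(B)" match
    parse_choice("It sounds like a violin to me.", all_choices, index2ans)  # -> "B" via the answer-text match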
llava/eval/registry_audio.yaml ADDED
@@ -0,0 +1,93 @@
1
+ Clotho-AQA-AQA:
2
+ tags:
3
+ - local
4
+ Music-AVQA-AQA_All:
5
+ tags:
6
+ - local
7
+ CochlScene-SceneClassification:
8
+ tags:
9
+ - local
10
+ NSynth-Source:
11
+ tags:
12
+ - local
13
+ NSynth-Instrument:
14
+ tags:
15
+ - local
16
+ FSD50k-EventClassification:
17
+ tags:
18
+ - local
19
+ Clotho-v2-AudioCaptioning:
20
+ tags:
21
+ - local
22
+ audiocaps-AudioCaptioning:
23
+ tags:
24
+ - local
25
+ ravdess-EmotionClassification:
26
+ tags:
27
+ - local
28
+ GTZAN-GenreClassification:
29
+ tags:
30
+ - local
31
+ UrbanSound8K-EventClassification:
32
+ tags:
33
+ - local
34
+ Medley-solos-DB-InstrClassification:
35
+ tags:
36
+ - local
37
+ ESC50-EventClassification:
38
+ tags:
39
+ - local
40
+ CREMA-D-EmotionClassification:
41
+ tags:
42
+ - local
43
+ IEMOCAP-EmotionClassification:
44
+ tags:
45
+ - local
46
+ MELD-EmotionClassification:
47
+ tags:
48
+ - local
49
+ MELD-SentimentClassification:
50
+ tags:
51
+ - local
52
+ MMAU:
53
+ tags:
54
+ - local
55
+ AudioEntailmentQA:
56
+ tags:
57
+ - local
58
+ SPGI-ASR:
59
+ tags:
60
+ - local
61
+ SWBD-ASR:
62
+ tags:
63
+ - local
64
+ LibriSpeech-ASR-clean:
65
+ tags:
66
+ - local
67
+ LibriSpeech-ASR-other:
68
+ tags:
69
+ - local
70
+ VoxPopuli-ASR:
71
+ tags:
72
+ - local
73
+ Europarl-ASR:
74
+ tags:
75
+ - local
76
+ CV-ASR:
77
+ tags:
78
+ - local
79
+ GigaSpeech-ASR:
80
+ tags:
81
+ - local
82
+ CompA-R-AQA:
83
+ tags:
84
+ - local
85
+ MuschoMusicQA:
86
+ tags:
87
+ - local
88
+ CMM:
89
+ tags:
90
+ - local
91
+ AIR-Bench:
92
+ tags:
93
+ - local
llava/media.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ __all__ = ["Media", "File", "Image", "Video", "Speech", "Sound"]
24
+
25
+
26
+ class Media:
27
+ pass
28
+
29
+
30
+ class File(Media):
31
+ def __init__(self, path: str) -> None:
32
+ self.path = path
33
+
34
+
35
+ class Image(File):
36
+ pass
37
+
38
+
39
+ class Video(File):
40
+ pass
41
+
42
+
43
+ class Speech(File):
44
+ pass
45
+
46
+ class Sound(File):
47
+ pass
llava/mm_utils.py ADDED
@@ -0,0 +1,641 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ # dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
24
+
25
+ import base64
26
+ import os
27
+ import tempfile
28
+ from io import BytesIO
29
+
30
+ import numpy as np
31
+ import torch
32
+ from PIL import Image
33
+ from transformers import StoppingCriteria
34
+
35
+ from pydub import AudioSegment
36
+ from torchvision import transforms
37
+ import soundfile as sf
38
+ from librosa import resample as librosa_resample
39
+ import whisper
40
+ import random
41
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
42
+ DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
43
+ def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
44
+ import cv2
45
+
46
+ if fps is None or frame_count is None:
47
+ # if one of fps or frame_count is None, still recompute
48
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
49
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
50
+ if fps == 0 or frame_count == 0:
51
+ print(f"Video file not found. return empty images. {video_file_name}")
52
+ return [
53
+ Image.new("RGB", (720, 720)),
54
+ ] * num_frames, 0, [0.]
55
+
56
+ duration = frame_count / fps
57
+ frame_interval = frame_count // num_frames
58
+ if frame_interval == 0 and frame_count <= 1:
59
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
60
+ return [
61
+ Image.new("RGB", (720, 720)),
62
+ ] * num_frames, 0, [0.]
63
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
64
+
65
+ images = []
66
+ count = 0
67
+ success = True
68
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
69
+ frame_times = [frame / fps for frame in frame_indices]
70
+ while success:
71
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
72
+ if frame_count >= num_frames:
73
+ success, frame = vidcap.read()
74
+ if count in frame_indices:
75
+ try:
76
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
77
+ im_pil = Image.fromarray(img)
78
+ images.append(im_pil)
79
+ except BaseException:
80
+ continue
81
+ if len(images) >= num_frames:
82
+ return images, num_frames, frame_times
83
+ count += 1
84
+ else:
85
+ # Left padding frames if the video is not long enough
86
+ success, frame = vidcap.read()
87
+ if success:
88
+ try:
89
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
90
+ im_pil = Image.fromarray(img)
91
+ images.append(im_pil)
92
+ except BaseException:
93
+ continue
94
+ count += 1
95
+ else:
96
+ break
97
+ if len(images) == 0:
98
+ raise ValueError("Did not find enough frames in the video. return empty image.")
99
+
100
+ return images, len(images), frame_times
101
+
102
+
103
+ def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
104
+ """
105
+ num_frames is the max number of frames the model can support.
106
+ frame_count is the number of frames in the input video.
107
+ max_fps is the max FPS of the model can support.
108
+ fps is the fps of the input video.
109
+ """
110
+
111
+ import random
112
+
113
+ import cv2
114
+
115
+ if fps is None or frame_count is None:
116
+ # if one of fps or frame_count is None, still recompute
117
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
118
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
119
+
120
+ if fps == 0 or frame_count == 0:
121
+ print(f"Video file not found. return empty images. {video_file_name}")
122
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
123
+ return [
124
+ Image.new("RGB", (720, 720)),
125
+ ] * empty_video_frames, 0, [0.]
126
+
127
+ duration = frame_count / fps
128
+ # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
129
+ # If the video is too long (longer than max_fps and num_frames can support),
130
+ # we will use lower fps to sample frames.
131
+ if duration >= num_frames / max_fps:
132
+ frame_interval = frame_count // num_frames
133
+
134
+ # If the video is too short, we will skip the video if there is only one frame.
135
+ if frame_interval == 0 and frame_count <= 1:
136
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
137
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
138
+ return [
139
+ Image.new("RGB", (720, 720)),
140
+ ] * empty_video_frames, 0, [0.]
141
+
142
+ images = []
143
+ count = 0
144
+ success = True
145
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
146
+ frame_times = [frame / fps for frame in frame_indices]
147
+ while success:
148
+ if frame_count >= num_frames:
149
+ # success, frame = vidcap.read()
150
+ if count in frame_indices:
151
+ success, frame = vidcap.read()
152
+ try:
153
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
154
+ im_pil = Image.fromarray(img)
155
+ images.append(im_pil)
156
+ except:
157
+ # print("Failed to read frame:", count)
158
+ continue
159
+ if len(images) >= num_frames:
160
+ return images, num_frames, frame_times
161
+ else:
162
+ success = vidcap.grab()
163
+ count += 1
164
+ else:
165
+ # Left padding frames if the video is not long enough
166
+ success, frame = vidcap.read()
167
+ if success:
168
+ try:
169
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
170
+ im_pil = Image.fromarray(img)
171
+ images.append(im_pil)
172
+ except:
173
+ # print("Failed to read frame:", count)
174
+ continue
175
+ count += 1
176
+ else:
177
+ break
178
+ else:
179
+ frames_required = int(duration * max_fps)
180
+ frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
181
+ if frames_required == 0:
182
+ print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
183
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
184
+ return [
185
+ Image.new("RGB", (720, 720)),
186
+ ] * empty_video_frames, 0, [0.]
187
+ elif frames_required == 1:
188
+ frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
189
+ images = []
190
+ count = 0
191
+ looked = 0
192
+ success = True
193
+
194
+ while success:
195
+ success, frame = vidcap.read()
196
+ if success and (looked in frame_indices):
197
+ try:
198
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
199
+ im_pil = Image.fromarray(img)
200
+ images.append(im_pil)
201
+ except:
202
+ continue
203
+ count += 1
204
+ looked += 1
205
+ frame_times = [frame / fps for frame in frame_indices]
206
+
207
+ if len(images) == 0:
208
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
209
+ return [
210
+ Image.new("RGB", (720, 720)),
211
+ ] * empty_video_frames, 0, [0.]
212
+ else:
213
+ return images, len(images), frame_times
214
+
215
+
216
+ def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
217
+ """
218
+ Extract frames from a video using OpenCV.
219
+
220
+ Args:
221
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
222
+ frames (int): Number of frames to extract from the video.
223
+ fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
224
+
225
+ Returns:
226
+ list: List of PIL Images extracted from the video.
227
+
228
+ Raises:
229
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
230
+ """
231
+ import cv2
232
+ if isinstance(vpath_or_bytesio, str):
233
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
234
+ if max_fps > 0.0:
235
+ return get_frame_from_vcap_with_fps(
236
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
237
+ )
238
+ return get_frame_from_vcap(
239
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
240
+ )
241
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
242
+ # assuming mp4
243
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
244
+ temp_video.write(vpath_or_bytesio.read())
245
+ temp_video_name = temp_video.name
246
+ vidcap = cv2.VideoCapture(temp_video_name)
247
+ if max_fps > 0.0:
248
+ return get_frame_from_vcap_with_fps(
249
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
250
+ )
251
+ return get_frame_from_vcap(
252
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
253
+ )
254
+ else:
255
+ raise NotImplementedError(type(vpath_or_bytesio))
256
+
257
+
258
+ def load_image_from_base64(image):
259
+ return Image.open(BytesIO(base64.b64decode(image)))
260
+
261
+
262
+ def expand2square(pil_img, background_color):
263
+ """
264
+ Expand the given PIL image to a square shape by adding padding.
265
+
266
+ Parameters:
267
+ - pil_img: The PIL image to be expanded.
268
+ - background_color: The color of the padding to be added.
269
+
270
+ Returns:
271
+ - The expanded PIL image.
272
+
273
+ If the image is already square, it is returned as is.
274
+ If the image is wider than it is tall, padding is added to the top and bottom.
275
+ If the image is taller than it is wide, padding is added to the left and right.
276
+ """
277
+ width, height = pil_img.size
278
+ if pil_img.mode == "L":
279
+ background_color = background_color[0]
280
+ if width == height:
281
+ return pil_img
282
+ elif width > height:
283
+ result = Image.new(pil_img.mode, (width, width), background_color)
284
+ result.paste(pil_img, (0, (width - height) // 2))
285
+ return result
286
+ else:
287
+ result = Image.new(pil_img.mode, (height, height), background_color)
288
+ result.paste(pil_img, ((height - width) // 2, 0))
289
+ return result
290
+
291
+
292
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
293
+ best_ratio_diff = float("inf")
294
+ best_ratio = (1, 1)
295
+ area = width * height
296
+ for ratio in target_ratios:
297
+ target_aspect_ratio = ratio[0] / ratio[1]
298
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
299
+ if ratio_diff < best_ratio_diff:
300
+ best_ratio_diff = ratio_diff
301
+ best_ratio = ratio
302
+ elif ratio_diff == best_ratio_diff:
303
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
304
+ best_ratio = ratio
305
+ return best_ratio
306
+
307
+
308
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
309
+ orig_width, orig_height = image.size
310
+ aspect_ratio = orig_width / orig_height
311
+
312
+ # calculate the existing image aspect ratio
313
+ target_ratios = {
314
+ (i, j)
315
+ for n in range(min_num, max_num + 1)
316
+ for i in range(1, n + 1)
317
+ for j in range(1, n + 1)
318
+ if i * j <= max_num and i * j >= min_num
319
+ }
320
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
321
+
322
+ # find the closest aspect ratio to the target
323
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
324
+
325
+ # calculate the target width and height
326
+ target_width = image_size * target_aspect_ratio[0]
327
+ target_height = image_size * target_aspect_ratio[1]
328
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
329
+
330
+ # resize the image
331
+ resized_img = image.resize((target_width, target_height))
332
+ processed_images = []
333
+ for i in range(blocks):
334
+ box = (
335
+ (i % (target_width // image_size)) * image_size,
336
+ (i // (target_width // image_size)) * image_size,
337
+ ((i % (target_width // image_size)) + 1) * image_size,
338
+ ((i // (target_width // image_size)) + 1) * image_size,
339
+ )
340
+ # split the image
341
+ split_img = resized_img.crop(box)
342
+ processed_images.append(split_img)
343
+ assert len(processed_images) == blocks
344
+ if use_thumbnail and len(processed_images) != 1:
345
+ thumbnail_img = image.resize((image_size, image_size))
346
+ processed_images.append(thumbnail_img)
347
+ return processed_images
348
+
349
+
350
+ def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
351
+ orig_width, orig_height = image.size
352
+ aspect_ratio = orig_width / orig_height
353
+ min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale
354
+
355
+ processed_images = []
356
+
357
+ ##########################################################################################
358
+ ############# Add tiles for all but the last scale using fixed square ratio ###############
359
+ ##########################################################################################
360
+
361
+ for scale in s2_scales[:-1]:
362
+ target_width = image_size * (scale // s2_scales[0])
363
+ target_height = image_size * (scale // s2_scales[0])
364
+ blocks = (scale // s2_scales[0]) ** 2
365
+
366
+ # resize the image
367
+ resized_img = image.resize((target_width, target_height))
368
+ for i in range(blocks):
369
+ box = (
370
+ (i % (target_width // image_size)) * image_size,
371
+ (i // (target_width // image_size)) * image_size,
372
+ ((i % (target_width // image_size)) + 1) * image_size,
373
+ ((i // (target_width // image_size)) + 1) * image_size,
374
+ )
375
+ # split the image
376
+ split_img = resized_img.crop(box)
377
+ processed_images.append(split_img)
378
+
379
+ ##########################################################################################
380
+ ################ Add tiles for the last scale using dynamic aspect ratio #################
381
+ ##########################################################################################
382
+
383
+ # calculate the existing image aspect ratio
384
+ target_ratios = {
385
+ (i, j)
386
+ for n in range(min_num, max_num + 1)
387
+ for i in range(1, n + 1)
388
+ for j in range(1, n + 1)
389
+ if i * j <= max_num and i * j >= min_num
390
+ }
391
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
392
+
393
+ # find the closest aspect ratio to the target
394
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
395
+
396
+ # calculate the target width and height
397
+ target_width = image_size * target_aspect_ratio[0]
398
+ target_height = image_size * target_aspect_ratio[1]
399
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
400
+
401
+ # resize the image
402
+ resized_img = image.resize((target_width, target_height))
403
+ for i in range(blocks):
404
+ box = (
405
+ (i % (target_width // image_size)) * image_size,
406
+ (i // (target_width // image_size)) * image_size,
407
+ ((i % (target_width // image_size)) + 1) * image_size,
408
+ ((i // (target_width // image_size)) + 1) * image_size,
409
+ )
410
+ # split the image
411
+ split_img = resized_img.crop(box)
412
+ processed_images.append(split_img)
413
+
414
+ return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
415
+
416
+
417
+
418
+ def dynamic_s2_process_images_and_prompt(images, data_args, image_folder=None):
419
+ idx = 0
420
+ all_images = []
421
+ all_block_size = []
422
+ for img in images:
423
+ processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
424
+ all_images.append(processed_images)
425
+ all_block_size.append(block_size)
426
+ idx += 2
427
+ if all_images:
428
+ all_images = torch.cat(all_images)
429
+ else:
430
+ all_images = None
431
+ return all_images, all_block_size
432
+
433
+
434
+ def process_image(
435
+ image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
436
+ ):
437
+ processor = data_args.image_processor
438
+ if isinstance(image_file, str):
439
+ if image_folder is not None:
440
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
441
+ else:
442
+ image = Image.open(image_file).convert("RGB")
443
+ else:
444
+ # image is stored in bytearray
445
+ image = image_file
446
+ image = image.convert("RGB")
447
+ if hasattr(data_args.image_processor, "crop_size"):
448
+ # CLIP vision tower
449
+ crop_size = data_args.image_processor.crop_size
450
+ else:
451
+ # SIGLIP vision tower
452
+ assert hasattr(data_args.image_processor, "size")
453
+ crop_size = data_args.image_processor.size
454
+ if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
455
+ assert crop_size["height"] == crop_size["width"]
456
+ images, block_size = dynamic_s2_preprocess(
457
+ image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
458
+ )
459
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
460
+ return torch.stack(images), block_size
461
+ if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
462
+ assert crop_size["height"] == crop_size["width"]
463
+ if max_tiles is not None:
464
+ max_num = max_tiles
465
+ else:
466
+ max_num = data_args.max_tiles
467
+ images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
468
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
469
+ return torch.stack(images)
470
+
471
+ if data_args.image_aspect_ratio == "resize":
472
+ image = image.resize((crop_size["width"], crop_size["height"]))
473
+ if data_args.image_aspect_ratio == "pad":
474
+
475
+ def expand2square(pil_img, background_color):
476
+ width, height = pil_img.size
477
+ if width == height:
478
+ return pil_img
479
+ elif width > height:
480
+ result = Image.new(pil_img.mode, (width, width), background_color)
481
+ result.paste(pil_img, (0, (width - height) // 2))
482
+ return result
483
+ else:
484
+ result = Image.new(pil_img.mode, (height, height), background_color)
485
+ result.paste(pil_img, ((height - width) // 2, 0))
486
+ return result
487
+
488
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
489
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
490
+ else:
491
+ # Using default behavior of the vision encoder
492
+ # For CLIP, default is central crop
493
+ # For Radio, default is central crop
494
+ # For Siglip, default is resize
495
+ # For InternVIT, default is resize
496
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
497
+ return image
498
+
499
+ def get_num_windows(T, sr, max_num_window=5):
500
+
501
+ window_length = int(30.0 * sr)
502
+ window_overlap = int(0.0 * sr)
503
+ max_num_window = max_num_window
504
+
505
+ num_windows = 1
506
+ if T <= window_length:
507
+ num_windows = 1
508
+ full_length = window_length
509
+ elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
510
+ num_windows = max_num_window
511
+ full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
512
+ else:
513
+ num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
514
+ full_length = num_windows * window_length - (num_windows - 1) * window_overlap
515
+
516
+ return num_windows, full_length
517
+
518
+ def load_audio(file_path, target_sr=16000, duration=30.0, start=0.0):
519
+ if file_path.endswith('.mp3'):
520
+ audio = AudioSegment.from_file(file_path)
521
+ if len(audio) > (start + duration) * 1000:
522
+ audio = audio[start * 1000:(start + duration) * 1000]
523
+
524
+ if audio.frame_rate != target_sr:
525
+ audio = audio.set_frame_rate(target_sr)
526
+
527
+ if audio.channels > 1:
528
+ audio = audio.set_channels(1)
529
+
530
+ data = np.array(audio.get_array_of_samples())
531
+ if audio.sample_width == 2:
532
+ data = data.astype(np.float32) / np.iinfo(np.int16).max
533
+ elif audio.sample_width == 4:
534
+ data = data.astype(np.float32) / np.iinfo(np.int32).max
535
+ else:
536
+ raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
537
+
538
+ else:
539
+ with sf.SoundFile(file_path) as audio:
540
+ original_sr = audio.samplerate
541
+ channels = audio.channels
542
+
543
+ max_frames = int((start + duration) * original_sr)
544
+
545
+ audio.seek(int(start * original_sr))
546
+ frames_to_read = min(max_frames, len(audio))
547
+ data = audio.read(frames_to_read)
548
+
549
+ if data.max() > 1 or data.min() < -1:
550
+ data = data / max(abs(data.max()), abs(data.min()))
551
+
552
+ if original_sr != target_sr:
553
+ if channels == 1:
554
+ data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
555
+ else:
556
+ data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
557
+ else:
558
+ if channels != 1:
559
+ data = data.T[0]
560
+
561
+ if data.min() >= 0:
562
+ data = 2 * data / abs(data.max()) - 1.0
563
+ else:
564
+ data = data / max(abs(data.max()), abs(data.min()))
565
+
566
+ assert len(data.shape) == 1, data.shape
567
+ return data
568
+
569
+ def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
570
+ model_cfg.image_processor = image_processor
571
+ new_images = [
572
+ process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
573
+ for image in images
574
+ ]
575
+
576
+ if all(x.shape == new_images[0].shape for x in new_images):
577
+ if len(new_images[0].shape) == 4:
578
+ new_images = torch.cat(new_images, dim=0)
579
+ elif len(new_images[0].shape) == 3:
580
+ new_images = torch.stack(new_images, dim=0)
581
+ else:
582
+ raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
583
+ else:
584
+ raise ValueError("The shape of images in new_images is different!")
585
+ return new_images
586
+
587
+ def process_sounds(sounds):
588
+ sounds = torch.tensor(sounds)
589
+ return sounds
590
+
591
+ def process_sound_masks(masks):
592
+ masks = torch.tensor(masks[0])
593
+ return masks
594
+
595
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
596
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
597
+
598
+ def is_gemma_tokenizer(tokenizer):
599
+ return "gemma" in tokenizer.__class__.__name__.lower()
600
+
601
+
602
+ def get_model_name_from_path(model_path):
603
+ model_path = model_path.strip("/")
604
+ model_paths = model_path.split("/")
605
+ if model_paths[-1].startswith("checkpoint-"):
606
+ return model_paths[-2] + "_" + model_paths[-1]
607
+ else:
608
+ return model_paths[-1]
609
+
610
+ class KeywordsStoppingCriteria(StoppingCriteria):
611
+ def __init__(self, keywords, tokenizer, input_ids):
612
+ self.keywords = keywords
613
+ self.keyword_ids = []
614
+ self.max_keyword_len = 0
615
+ for keyword in keywords:
616
+ cur_keyword_ids = tokenizer(keyword).input_ids
617
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
618
+ cur_keyword_ids = cur_keyword_ids[1:]
619
+ if len(cur_keyword_ids) > self.max_keyword_len:
620
+ self.max_keyword_len = len(cur_keyword_ids)
621
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
622
+ self.tokenizer = tokenizer
623
+ self.start_len = input_ids.shape[1]
624
+
625
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
626
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
627
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
628
+ for keyword_id in self.keyword_ids:
629
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
630
+ return True
631
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
632
+ for keyword in self.keywords:
633
+ if keyword in outputs:
634
+ return True
635
+ return False
636
+
637
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
638
+ outputs = []
639
+ for i in range(output_ids.shape[0]):
640
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
641
+ return all(outputs)
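As a worked example of the audio windowing in get_num_windows above (16 kHz audio, 30 s windows, no overlap): a 75 s clip gives num_windows = 1 + ceil((75 - 30) / 30) = 3 and full_length equal to 3 windows' worth of samples:

    sr = 16000
    T = int(75.0 * sr)                     # 1,200,000 samples
    num_windows, full_length = get_num_windows(T, sr, max_num_window=5)
    # num_windows == 3, full_length == 1,440,000 samples (i.e. 90 s worth, the clip is framed into 3 windows)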
llava/model/FloatPointQuantizeTorch.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import math
8
+
9
+ import torch
10
+
11
+
12
+ def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
13
+ sign, x_abs = x.sign(), x.abs()
14
+ Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit
15
+ expo = torch.floor(torch.log2(x_abs))
16
+ expo = torch.clamp(expo, min=Elow, max=Ehigh)
17
+ mant = x_abs / torch.exp2(expo)
18
+
19
+ mant_int = torch.floor(mant)
20
+ mant_frac = mant - mant_int
21
+ mant_frac = mant_frac * Mhigh
22
+ if stochastic:
23
+ noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5)
24
+ mant_frac.add_(noise)
25
+ mant_frac = torch.round(mant_frac)
26
+
27
+ mant_q = mant_int + mant_frac / Mhigh
28
+ y = sign * (2**expo) * mant_q
29
+ y = y.to(x)
30
+
31
+ return y
32
+
33
+
34
+ def floatExM0_quantize_torch(x, e_bit, stochastic):
35
+ sign, x_abs = x.sign(), x.abs()
36
+ Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1)
37
+ expo = torch.log2(x_abs)
38
+ if stochastic:
39
+ noise = expo.new(expo.shape).uniform_(-0.5, 0.5)
40
+ expo.add_(noise)
41
+ log_bias = math.log2(4 / 3) - 1 / 2
42
+ expo.add_(torch.ones_like(expo) * log_bias)
43
+ expo = torch.clamp(expo, min=Elow - 1, max=Ehigh)
44
+ expo = torch.round(expo)
45
+
46
+ y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0
47
+ y = y.to(x)
48
+
49
+ return y
50
+
51
+
52
+ def Dynamic_quantize_torch(x, bit, stochastic):
53
+ if stochastic:
54
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
55
+ sign, x_abs = x.sign(), x.abs()
56
+ expo = torch.ceil(torch.log10(x_abs))
57
+ expo = torch.clamp(expo, min=2 - bit)
58
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
59
+
60
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
61
+ mant_frac = torch.round(mant_frac)
62
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
63
+ y = sign * (10**expo) * mant_frac / 10
64
+
65
+ zero_mask = y.abs() > 1.01 * 10 ** (1 - bit)
66
+ y = y * zero_mask
67
+ y = y.to(x)
68
+ return y
69
+
70
+
71
+ def ZeroDynamic_quantize_torch(x, bit, stochastic):
72
+ if stochastic:
73
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
74
+ sign, x_abs = x.sign(), x.abs()
75
+ expo = torch.ceil(torch.log10(x_abs))
76
+ expo = torch.clamp(expo, min=2 - bit)
77
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
78
+
79
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
80
+ mant_frac = torch.round(mant_frac)
81
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
82
+ y = sign * (10**expo) * mant_frac / 10
83
+
84
+ y = y.to(x)
85
+ return y
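A small sanity check of floatExMy_quantize_torch above, which rounds each value to sign * 2^e * (1 + m / 2^m_bit) with the exponent clamped to an E4M3-style range (values chosen by hand):

    import torch

    x = torch.tensor([0.1, 1.0, 3.3, -7.9], dtype=torch.float32)
    y = floatExMy_quantize_torch(x, e_bit=4, m_bit=3, stochastic=False)
    # deterministic rounding: 0.1 -> 0.1015625, 1.0 -> 1.0, 3.3 -> 3.25, -7.9 -> -8.0
    print(y)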
llava/model/FloatPointQuantizeTriton.py ADDED
@@ -0,0 +1,199 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import math
8
+ import struct
9
+
10
+ import numpy as np
11
+ import torch
12
+ import triton
13
+ import triton.language as tl
14
+ from triton.language.extra.cuda import libdevice
15
+
16
+ segment_size = 1024**3
17
+
18
+
19
+ def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic):
20
+ x_ori_shape = x.shape
21
+ x = x.view(-1)
22
+
23
+ n_elements = x.numel()
24
+
25
+ if n_elements <= segment_size:
26
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
27
+ y = torch.empty_like(x)
28
+
29
+ if x.dtype in [torch.bfloat16, torch.float32]:
30
+ if stochastic:
31
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
32
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
33
+ else:
34
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
35
+ torch.cuda.synchronize()
36
+ else:
37
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
38
+ else: # Triton will break when x.numel > 2 * 1024 ** 3
39
+ num_segments = n_elements // segment_size + 1
40
+ split_size = [segment_size] * (num_segments - 1) + [n_elements - segment_size * (num_segments - 1)]
41
+ x_list = x.split(split_size)
42
+ y_list = []
43
+ del x
44
+
45
+ for x in x_list:
46
+ n_elements = x.numel()
47
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
48
+ y = torch.empty_like(x)
49
+
50
+ if x.dtype in [torch.bfloat16, torch.float32]:
51
+ if stochastic:
52
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
53
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
54
+ else:
55
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
56
+ torch.cuda.synchronize()
57
+ else:
58
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
59
+
60
+ y_list.append(y)
61
+ y = torch.concat(y_list)
62
+ del y_list
63
+
64
+ y = y.reshape(x_ori_shape)
65
+ return y
66
+
67
+
68
+ @triton.autotune(
69
+ configs=[
70
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
71
+ triton.Config(
72
+ {
73
+ "BLOCK_SIZE": 1024,
74
+ },
75
+ num_warps=4,
76
+ ),
77
+ triton.Config(
78
+ {
79
+ "BLOCK_SIZE": 2048,
80
+ },
81
+ num_warps=4,
82
+ ),
83
+ ],
84
+ key=["n_elements"],
85
+ )
86
+ @triton.jit
87
+ def _floatExMy_quantize_kernel(
88
+ x_ptr,
89
+ output_ptr,
90
+ n_elements,
91
+ e_bit,
92
+ m_bit,
93
+ BLOCK_SIZE: tl.constexpr,
94
+ ):
95
+ if isinstance(e_bit, tl.constexpr):
96
+ ebit = e_bit.value
97
+ else:
98
+ ebit = e_bit
99
+
100
+ if isinstance(m_bit, tl.constexpr):
101
+ mbit = m_bit.value
102
+ else:
103
+ mbit = m_bit
104
+
105
+ pid = tl.program_id(axis=0)
106
+ block_start = pid * BLOCK_SIZE
107
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
108
+ mask = offsets < n_elements
109
+ x = tl.load(x_ptr + offsets, mask=mask)
110
+
111
+ x = x.to(tl.float32)
112
+ sign = 1 - 2 * libdevice.signbit(x)
113
+ x_abs = tl.abs(x)
114
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
115
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
116
+ Mhigh = tl.exp2(mbit.to(tl.float32))
117
+ expo = tl.floor(tl.log2(x_abs))
118
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
119
+ mant = x_abs / tl.exp2(expo)
120
+
121
+ mant_int = tl.floor(mant)
122
+ mant_frac = mant - mant_int
123
+ mant_frac = mant_frac * Mhigh
124
+ # mant_frac = mant_frac + noise
125
+ mant_frac = libdevice.round(mant_frac)
126
+
127
+ mant_q = mant_int + mant_frac / Mhigh
128
+ y = sign * tl.exp2(expo) * mant_q
129
+ y = y.to(x_ptr.dtype.element_ty)
130
+
131
+ tl.store(output_ptr + offsets, y, mask=mask)
132
+
133
+
134
+ @triton.autotune(
135
+ configs=[
136
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
137
+ triton.Config(
138
+ {
139
+ "BLOCK_SIZE": 1024,
140
+ },
141
+ num_warps=4,
142
+ ),
143
+ triton.Config(
144
+ {
145
+ "BLOCK_SIZE": 2048,
146
+ },
147
+ num_warps=4,
148
+ ),
149
+ ],
150
+ key=["n_elements"],
151
+ )
152
+ @triton.jit
153
+ def _floatExMy_stochastic_quantize_kernel(
154
+ x_ptr,
155
+ noise_ptr,
156
+ output_ptr,
157
+ n_elements,
158
+ e_bit,
159
+ m_bit,
160
+ BLOCK_SIZE: tl.constexpr,
161
+ ):
162
+ if isinstance(e_bit, tl.constexpr):
163
+ ebit = e_bit.value
164
+ else:
165
+ ebit = e_bit
166
+
167
+ if isinstance(m_bit, tl.constexpr):
168
+ mbit = m_bit.value
169
+ else:
170
+ mbit = m_bit
171
+
172
+ pid = tl.program_id(axis=0)
173
+ block_start = pid * BLOCK_SIZE
174
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
175
+ mask = offsets < n_elements
176
+ x = tl.load(x_ptr + offsets, mask=mask)
177
+ noise = tl.load(noise_ptr + offsets, mask=mask)
178
+
179
+ x = x.to(tl.float32)
180
+ sign = 1 - 2 * libdevice.signbit(x)
181
+ x_abs = tl.abs(x)
182
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
183
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
184
+ Mhigh = tl.exp2(mbit.to(tl.float32))
185
+ expo = tl.floor(tl.log2(x_abs))
186
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
187
+ mant = x_abs / tl.exp2(expo)
188
+
189
+ mant_int = tl.floor(mant)
190
+ mant_frac = mant - mant_int
191
+ mant_frac = mant_frac * Mhigh
192
+ mant_frac = mant_frac + noise
193
+ mant_frac = libdevice.round(mant_frac)
194
+
195
+ mant_q = mant_int + mant_frac / Mhigh
196
+ y = sign * tl.exp2(expo) * mant_q
197
+ y = y.to(x_ptr.dtype.element_ty)
198
+
199
+ tl.store(output_ptr + offsets, y, mask=mask)
llava/model/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ from .language_model.llava_llama import LlavaLlamaConfig, LlavaLlamaModel
8
+
9
+ # FP8 related comments, development in progress (PI: ligeng zhu, haochen xi)
10
+ # NOTE: VLM + LLM
11
+ # from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel
12
+ # NOTE: Linear -> fp8, similar to transformer engine
13
+ # from .language_model.qllama import QLlamaConfig, QLlamaForCausalLM, QLlamaModel
14
+ # NOTE: Linear + Activation -> fp8, haochen's iclr version
15
+ # from .language_model.qmemllama import QMemLlamaConfig, QMemLlamaForCausalLM, QMemLlamaModel
16
+ """
17
+ TODO:
18
+ linear(weights):
19
+ simulated fp8: done
20
+ real fp8: in-progress (code already implemented)
21
+ activation:
22
+ simulated fp8: done
23
+ real fp8: in-progress (still coding)
24
+ optimizers:
25
+ current VILA: bf16
26
+ simulated fp8: done
27
+ real fp8 + fsdp (single node): done
28
+ real fp8 + fsdp (multiple node): in-progress
29
+ 1. linear fp8
30
+ 2. activation fp8
31
+ 3. fp8 inference example (load directly from an fp8 checkpoint and fwd)
32
+ 4. bind fp8 related configs to QLlamaConfig {"coat_fp8_args": {}}
33
+ """
34
+ from .language_model.fp8linearqwen2 import FP8LinearQwen2Config, FP8LinearQwen2Model
35
+ from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel
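For reference, a minimal sketch of how the classes re-exported here are typically consumed (the checkpoint path is a placeholder; the `from_pretrained` call mirrors the usage in `llava/model/builder.py` later in this commit):

```python
# Minimal usage sketch for the exports above; the path is a placeholder.
from llava.model import LlavaLlamaConfig, LlavaLlamaModel

config = LlavaLlamaConfig.from_pretrained("/path/to/checkpoint")   # placeholder path
model = LlavaLlamaModel.from_pretrained(
    "/path/to/checkpoint", config=config, low_cpu_mem_usage=True   # placeholder path
)
```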
llava/model/apply_delta.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
24
+
25
+ """
26
+ Usage:
27
+ python3 -m llava.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path <path-or-repo-of-delta>
28
+ """
29
+ import argparse
30
+
31
+ import torch
32
+ from tqdm import tqdm
33
+ from transformers import AutoModelForCausalLM, AutoTokenizer
34
+
35
+ from llava import LlavaLlamaForCausalLM
36
+
37
+
38
+ def apply_delta(base_model_path, target_model_path, delta_path):
39
+ print("Loading base model")
40
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
41
+
42
+ print("Loading delta")
43
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
44
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
45
+
46
+ print("Applying delta")
47
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
48
+ if name not in base.state_dict():
49
+ assert name in [
50
+ "model.mm_projector.weight",
51
+ "model.mm_projector.bias",
52
+ ], f"{name} not in base model"
53
+ continue
54
+ if param.data.shape == base.state_dict()[name].shape:
55
+ param.data += base.state_dict()[name]
56
+ else:
57
+ assert name in [
58
+ "model.embed_tokens.weight",
59
+ "lm_head.weight",
60
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
61
+ bparam = base.state_dict()[name]
62
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
63
+
64
+ print("Saving target model")
65
+ delta.save_pretrained(target_model_path)
66
+ delta_tokenizer.save_pretrained(target_model_path)
67
+
68
+
69
+ if __name__ == "__main__":
70
+ parser = argparse.ArgumentParser()
71
+ parser.add_argument("--base-model-path", type=str, required=True)
72
+ parser.add_argument("--target-model-path", type=str, required=True)
73
+ parser.add_argument("--delta-path", type=str, required=True)
74
+
75
+ args = parser.parse_args()
76
+
77
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
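A hedged usage sketch of the script above, called as a function rather than from the command line; every path below is a placeholder:

```python
# Apply a released delta on top of a base checkpoint (all paths are placeholders).
from llava.model.apply_delta import apply_delta

apply_delta(
    base_model_path="/path/to/base-llama",
    target_model_path="/path/to/merged-output",
    delta_path="/path/to/delta-checkpoint",
)
```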
llava/model/builder.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
8
+ # Copyright 2023 Haotian Liu
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import os
24
+ import warnings
25
+
26
+ import torch
27
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PretrainedConfig
28
+
29
+ from llava.model import LlavaLlamaModel
30
+ from llava.model.utils import is_mm_model
31
+
32
+
33
+ def load_pretrained_model(
34
+ model_path,
35
+ model_name,
36
+ model_base=None,
37
+ load_8bit=False,
38
+ load_4bit=False,
39
+ device_map="auto",
40
+ device="cuda",
41
+ **kwargs,
42
+ ):
43
+ kwargs = {"device_map": device_map, **kwargs}
44
+
45
+ if device != "cuda":
46
+ kwargs["device_map"] = {"": device}
47
+
48
+ if load_8bit:
49
+ kwargs["load_in_8bit"] = True
50
+ elif load_4bit:
51
+ kwargs["load_in_4bit"] = True
52
+ kwargs["quantization_config"] = BitsAndBytesConfig(
53
+ load_in_4bit=True,
54
+ bnb_4bit_compute_dtype=torch.float16,
55
+ bnb_4bit_use_double_quant=True,
56
+ bnb_4bit_quant_type="nf4",
57
+ )
58
+ else:
59
+ kwargs["torch_dtype"] = torch.float16
60
+ # kwargs["torch_dtype"] = torch.bfloat16
61
+
62
+ if is_mm_model(model_path):
63
+ # Load LLaVA model
64
+ ## TODO @yunhao: mind fixing lora
65
+ if "lora" in model_name.lower() and model_base is None:
66
+ warnings.warn(
67
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
68
+ )
69
+ if ("lora" in model_name.lower() or "dora" in model_name.lower()) and model_base is not None:
70
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
71
+ print(lora_cfg_pretrained)
72
+ print("Loading LLaVA from base model...")
73
+ config = AutoConfig.from_pretrained(model_base)
74
+ prepare_config_for_eval(config, kwargs)
75
+ model = LlavaLlamaModel.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
76
+ tokenizer = model.tokenizer
77
+ token_num, token_dim = model.llm.lm_head.out_features, model.llm.lm_head.in_features
78
+ if model.llm.lm_head.weight.shape[0] != token_num:
79
+ model.llm.lm_head.weight = torch.nn.Parameter(
80
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
81
+ )
82
+ model.llm.embed_tokens.weight = torch.nn.Parameter(
83
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
84
+ )
85
+
86
+ print("Loading additional LLaVA weights...")
87
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
88
+ non_lora_trainables = torch.load(
89
+ os.path.join(model_path, "non_lora_trainables.bin"),
90
+ map_location="cpu",
91
+ )
92
+ else:
93
+ # this is probably from HF Hub
94
+ from huggingface_hub import hf_hub_download
95
+
96
+ def load_from_hf(repo_id, filename, subfolder=None):
97
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
98
+ return torch.load(cache_file, map_location="cpu")
99
+
100
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
101
+ non_lora_trainables = {
102
+ (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
103
+ }
104
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
105
+ non_lora_trainables = {
106
+ (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
107
+ }
108
+ model.load_state_dict(non_lora_trainables, strict=False)
109
+
110
+ from peft import PeftModel
111
+
112
+ print("Loading LoRA weights...")
113
+ model = PeftModel.from_pretrained(model, model_path)
114
+ print("Merging LoRA weights...")
115
+ model = model.merge_and_unload()
116
+ print("Model is loaded...")
117
+ else:
118
+ config = AutoConfig.from_pretrained(model_path)
119
+ config.resume_path = model_path
120
+ prepare_config_for_eval(config, kwargs)
121
+ model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
122
+ tokenizer = model.tokenizer
123
+ else:
124
+ # Load language model
125
+ if model_base is not None:
126
+ # PEFT model
127
+ from peft import PeftModel
128
+
129
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
130
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
131
+ print(f"Loading LoRA weights from {model_path}")
132
+ model = PeftModel.from_pretrained(model, model_path)
133
+ print("Merging weights")
134
+ model = model.merge_and_unload()
135
+ print("Convert to FP16...")
136
+ model.to(torch.float16)
137
+ else:
138
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, legacy=False)
139
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
140
+ model.eval()
141
+ image_processor = None
142
+ if is_mm_model(model_path):
143
+ model.resize_token_embeddings(len(tokenizer))
144
+
145
+ if hasattr(model.llm.config, "max_sequence_length"):
146
+ context_len = model.config.max_sequence_length
147
+ else:
148
+ context_len = 2048
149
+
150
+ return tokenizer, model, image_processor, context_len
151
+
152
+
153
+ def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
154
+ try:
155
+ # compatible with deprecated config convention
156
+ if getattr(config, "vision_tower_cfg", None) is None:
157
+ config.vision_tower_cfg = config.mm_vision_tower
158
+ except AttributeError:
159
+ raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}")
160
+
161
+ config.model_dtype = kwargs.pop("torch_dtype").__str__()
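A hedged example of calling the builder above; the checkpoint path and name are placeholders, and `image_processor` comes back as `None` in the code shown, since only the tokenizer, model, and context length are initialized here:

```python
# Minimal evaluation-time loading sketch (path and name are placeholders).
from llava.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="/path/to/llava-checkpoint",
    model_name="llava-checkpoint",   # "lora"/"dora" in the name triggers the LoRA branch
    model_base=None,
    load_4bit=False,
)
print(type(model), context_len)
```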
llava/model/coat/activation/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import math
24
+
25
+ import torch
26
+
27
+
28
+ def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
29
+ sign, x_abs = x.sign(), x.abs()
30
+ Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit
31
+ expo = torch.floor(torch.log2(x_abs))
32
+ expo = torch.clamp(expo, min=Elow, max=Ehigh)
33
+ mant = x_abs / torch.exp2(expo)
34
+
35
+ mant_int = torch.floor(mant)
36
+ mant_frac = mant - mant_int
37
+ mant_frac = mant_frac * Mhigh
38
+ if stochastic:
39
+ noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5)
40
+ mant_frac.add_(noise)
41
+ mant_frac = torch.round(mant_frac)
42
+
43
+ mant_q = mant_int + mant_frac / Mhigh
44
+ y = sign * (2**expo) * mant_q
45
+ y = y.to(x)
46
+
47
+ return y
48
+
49
+
50
+ def floatExM0_quantize_torch(x, e_bit, stochastic):
51
+ sign, x_abs = x.sign(), x.abs()
52
+ Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1)
53
+ expo = torch.log2(x_abs)
54
+ if stochastic:
55
+ noise = expo.new(expo.shape).uniform_(-0.5, 0.5)
56
+ expo.add(noise)
57
+ log_bias = math.log2(4 / 3) - 1 / 2
58
+ expo.add(torch.ones_like(expo) * log_bias)
59
+ expo = torch.clamp(expo, min=Elow - 1, max=Ehigh)
60
+ expo = torch.round(expo)
61
+
62
+ y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0
63
+ y = y.to(x)
64
+
65
+ return y
66
+
67
+
68
+ def Dynamic_quantize_torch(x, bit, stochastic):
69
+ if stochastic:
70
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
71
+ sign, x_abs = x.sign(), x.abs()
72
+ expo = torch.ceil(torch.log10(x_abs))
73
+ expo = torch.clamp(expo, min=2 - bit)
74
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
75
+
76
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
77
+ mant_frac = torch.round(mant_frac)
78
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
79
+ y = sign * (10**expo) * mant_frac / 10
80
+
81
+ zero_mask = y.abs() > 1.01 * 10 ** (1 - bit)
82
+ y = y * zero_mask
83
+ y = y.to(x)
84
+ return y
85
+
86
+
87
+ def ZeroDynamic_quantize_torch(x, bit, stochastic):
88
+ if stochastic:
89
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
90
+ sign, x_abs = x.sign(), x.abs()
91
+ expo = torch.ceil(torch.log10(x_abs))
92
+ expo = torch.clamp(expo, min=2 - bit)
93
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
94
+
95
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
96
+ mant_frac = torch.round(mant_frac)
97
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
98
+ y = sign * (10**expo) * mant_frac / 10
99
+
100
+ y = y.to(x)
101
+ return y
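To make the ExMy arithmetic above concrete, a short deterministic example of `floatExMy_quantize_torch` with an E4M3-style format; the expected values follow from the code (0.3 has exponent -2 and mantissa 1.2, which rounds to 1.25, hence 0.3125):

```python
import torch

from llava.model.coat.activation.fake_quantization.FloatPointQuantizeTorch import floatExMy_quantize_torch

x = torch.tensor([0.3, -1.7, 0.0])
y = floatExMy_quantize_torch(x, e_bit=4, m_bit=3, stochastic=False)
# 0.3  -> 0.3125  (exponent -2, mantissa 1.2 rounds to 1.25)
# -1.7 -> -1.75   (exponent  0, mantissa 1.7 rounds to 1.75)
# 0.0  -> 0.0
print(y)
```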
llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import math
24
+ import struct
25
+
26
+ import numpy as np
27
+ import torch
28
+ import triton
29
+ import triton.language as tl
30
+ from triton.language.extra.cuda import libdevice
31
+
32
+
33
+ def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic):
34
+ n_elements = x.numel()
35
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
36
+ y = torch.zeros_like(x)
37
+
38
+ if x.dtype in [torch.bfloat16, torch.float32]:
39
+ if stochastic:
40
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
41
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
42
+ else:
43
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
44
+ else:
45
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
46
+
47
+ return y
48
+
49
+
50
+ @triton.autotune(
51
+ configs=[
52
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
53
+ triton.Config(
54
+ {
55
+ "BLOCK_SIZE": 1024,
56
+ },
57
+ num_warps=4,
58
+ ),
59
+ triton.Config(
60
+ {
61
+ "BLOCK_SIZE": 2048,
62
+ },
63
+ num_stages=1,
64
+ ),
65
+ ],
66
+ key=["n_elements"],
67
+ )
68
+ @triton.jit
69
+ def _floatExMy_quantize_kernel(
70
+ x_ptr,
71
+ output_ptr,
72
+ n_elements,
73
+ e_bit,
74
+ m_bit,
75
+ BLOCK_SIZE: tl.constexpr,
76
+ ):
77
+ if isinstance(e_bit, tl.constexpr):
78
+ ebit = e_bit.value
79
+ else:
80
+ ebit = e_bit
81
+
82
+ if isinstance(m_bit, tl.constexpr):
83
+ mbit = m_bit.value
84
+ else:
85
+ mbit = m_bit
86
+
87
+ pid = tl.program_id(axis=0)
88
+ block_start = pid * BLOCK_SIZE
89
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
90
+ mask = offsets < n_elements
91
+ x = tl.load(x_ptr + offsets, mask=mask)
92
+
93
+ x = x.to(tl.float32)
94
+ sign = 1 - 2 * libdevice.signbit(x)
95
+ x_abs = tl.abs(x)
96
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
97
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
98
+ Mhigh = tl.exp2(mbit.to(tl.float32))
99
+ expo = tl.floor(tl.log2(x_abs))
100
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
101
+ mant = x_abs / tl.exp2(expo)
102
+
103
+ mant_int = tl.floor(mant)
104
+ mant_frac = mant - mant_int
105
+ mant_frac = mant_frac * Mhigh
106
+ # mant_frac = mant_frac + noise
107
+ mant_frac = libdevice.round(mant_frac)
108
+
109
+ mant_q = mant_int + mant_frac / Mhigh
110
+ y = sign * tl.exp2(expo) * mant_q
111
+ y = y.to(x_ptr.dtype.element_ty)
112
+
113
+ tl.store(output_ptr + offsets, y, mask=mask)
114
+
115
+
116
+ @triton.autotune(
117
+ configs=[
118
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
119
+ triton.Config(
120
+ {
121
+ "BLOCK_SIZE": 1024,
122
+ },
123
+ num_warps=4,
124
+ ),
125
+ triton.Config(
126
+ {
127
+ "BLOCK_SIZE": 2048,
128
+ },
129
+ num_stages=1,
130
+ ),
131
+ ],
132
+ key=["n_elements"],
133
+ )
134
+ @triton.jit
135
+ def _floatExMy_stochastic_quantize_kernel(
136
+ x_ptr,
137
+ noise_ptr,
138
+ output_ptr,
139
+ n_elements,
140
+ e_bit,
141
+ m_bit,
142
+ BLOCK_SIZE: tl.constexpr,
143
+ ):
144
+ if isinstance(e_bit, tl.constexpr):
145
+ ebit = e_bit.value
146
+ else:
147
+ ebit = e_bit
148
+
149
+ if isinstance(m_bit, tl.constexpr):
150
+ mbit = m_bit.value
151
+ else:
152
+ mbit = m_bit
153
+
154
+ pid = tl.program_id(axis=0)
155
+ block_start = pid * BLOCK_SIZE
156
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
157
+ mask = offsets < n_elements
158
+ x = tl.load(x_ptr + offsets, mask=mask)
159
+ noise = tl.load(noise_ptr + offsets, mask=mask)
160
+
161
+ x = x.to(tl.float32)
162
+ sign = 1 - 2 * libdevice.signbit(x)
163
+ x_abs = tl.abs(x)
164
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
165
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
166
+ Mhigh = tl.exp2(mbit.to(tl.float32))
167
+ expo = tl.floor(tl.log2(x_abs))
168
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
169
+ mant = x_abs / tl.exp2(expo)
170
+
171
+ mant_int = tl.floor(mant)
172
+ mant_frac = mant - mant_int
173
+ mant_frac = mant_frac * Mhigh
174
+ mant_frac = mant_frac + noise
175
+ mant_frac = libdevice.round(mant_frac)
176
+
177
+ mant_q = mant_int + mant_frac / Mhigh
178
+ y = sign * tl.exp2(expo) * mant_q
179
+ y = y.to(x_ptr.dtype.element_ty)
180
+
181
+ tl.store(output_ptr + offsets, y, mask=mask)
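Since this file mirrors `FloatPointQuantizeTorch.py`, a hedged cross-check of the two deterministic paths is a quick way to catch divergence; it assumes a CUDA device with Triton available, and exact midpoints may round differently (`torch.round` is half-to-even, `libdevice.round` is not), so tiny discrepancies on ties are expected:

```python
import torch

from llava.model.coat.activation.fake_quantization.FloatPointQuantizeTorch import floatExMy_quantize_torch
from llava.model.coat.activation.fake_quantization.FloatPointQuantizeTriton import floatExMy_quantize_triton

# Requires a CUDA device: the Triton kernel only runs on GPU tensors.
x = torch.randn(4096, device="cuda", dtype=torch.float32)
y_ref = floatExMy_quantize_torch(x, 4, 3, stochastic=False)
y_tri = floatExMy_quantize_triton(x, 4, 3, stochastic=False)
print((y_ref - y_tri).abs().max())  # expected ~0 apart from rounding-tie differences
```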
llava/model/coat/activation/fake_quantization/quantize_function.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import re
24
+
25
+ import torch
26
+
27
+ from .FloatPointQuantizeTorch import *
28
+ from .FloatPointQuantizeTriton import *
29
+
30
+
31
+ def block_cut(input, row_block, column_block, pad_block=False):
32
+ # print(input.shape)
33
+ original_shape = input.shape
34
+ # input tensor shape is M * N
35
+ if len(input.shape) > 2:
36
+ input = input.reshape(-1, input.shape[2])
37
+ elif len(input.shape) == 2:
38
+ pass
39
+ else:
40
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
41
+ M, N = input.shape[0], input.shape[1]
42
+
43
+ if row_block == -1:
44
+ row_block = M
45
+ if column_block == -1:
46
+ column_block = N
47
+
48
+ if pad_block:
49
+ row_remainder, col_remainder = M % row_block, N % column_block
50
+ if row_remainder:
51
+ row_pad = row_block - row_remainder
52
+ else:
53
+ row_pad = 0
54
+ if col_remainder:
55
+ col_pad = column_block - col_remainder
56
+ else:
57
+ col_pad = 0
58
+
59
+ input = torch.nn.functional.pad(
60
+ input, (0, col_pad, 0, row_pad), "constant", 0
61
+ ) # refer to torch's doc to see why
62
+ M, N = input.shape[0], input.shape[1]
63
+ row_num, column_num = M // row_block, N // column_block
64
+ else:
65
+ row_num, column_num = M // row_block, N // column_block
66
+
67
+ assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}"
68
+ assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}"
69
+ # print(input.shape)
70
+ input = (
71
+ input.reshape(row_num, row_block, column_num, column_block)
72
+ .permute(0, 2, 1, 3)
73
+ .reshape(row_num * column_num, row_block, column_block)
74
+ )
75
+ # print(input.shape)
76
+ return input
77
+
78
+
79
+ def block_reshape(input, origin_input, row_block, column_block, pad_block=False):
80
+ if len(origin_input.shape) > 2:
81
+ flatten_input = origin_input.reshape(-1, origin_input.shape[2])
82
+ elif len(origin_input.shape) == 2:
83
+ flatten_input = origin_input
84
+ else:
85
+ raise ValueError(f"input shape {input.shape} does not match for block cut")
86
+
87
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
88
+
89
+ if row_block == -1:
90
+ row_block = M
91
+ if column_block == -1:
92
+ column_block = N
93
+
94
+ if pad_block:
95
+ row_remainder, col_remainder = M % row_block, N % column_block
96
+ if row_remainder:
97
+ row_pad = row_block - row_remainder
98
+ else:
99
+ row_pad = 0
100
+ if col_remainder:
101
+ col_pad = column_block - col_remainder
102
+ else:
103
+ col_pad = 0
104
+
105
+ pad_origin_input = torch.nn.functional.pad(origin_input, (0, col_pad, 0, row_pad), "constant", 0)
106
+ M, N = pad_origin_input.shape[0], pad_origin_input.shape[1]
107
+ row_num, column_num = M // row_block, N // column_block
108
+ else:
109
+ row_num, column_num = M // row_block, N // column_block
110
+
111
+ input = (
112
+ input.reshape(row_num, column_num, row_block, column_block)
113
+ .permute(0, 2, 1, 3)
114
+ .reshape(row_num * row_block, column_num * column_block)
115
+ )
116
+
117
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
118
+ input = input[:M, :N]
119
+
120
+ if len(origin_input.shape) > 2:
121
+ input = input.reshape(origin_input.shape)
122
+ elif len(origin_input.shape) == 2:
123
+ pass
124
+ else:
125
+ raise ValueError(f"input shape {input.shape} does not match for block reshape")
126
+
127
+ return input
128
+
129
+
130
+ def block_verify_int8(input, row_block, column_block, layer_type, necessary=True):
131
+ Binput = block_cut(input, row_block, column_block)
132
+ Binput = Binput.to(torch.float32)
133
+
134
+ for n in range(Binput.shape[0]):
135
+ unique_values = len(torch.unique(Binput[n, :, :]))
136
+ if unique_values > 256:
137
+ if necessary:
138
+ raise ValueError(f"{layer_type} contains more than 256 unique values.")
139
+ else:
140
+ return False
141
+ return True
142
+
143
+
144
+ def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name):
145
+ Quant_fn = SymmQuantizer
146
+ return Quant_fn.apply(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name)
147
+
148
+
149
+ def extract_bit(string):
150
+ match = re.match(r"INT(\d+)", string) # INT8
151
+ if match:
152
+ return "integer", int(match.group(1)), None
153
+ match = re.match(r"E(\d+)M(\d+)", string) # E4M3 / E5M2
154
+ if match:
155
+ Ebit, Mbit = int(match.group(1)), int(match.group(2))
156
+ if Ebit == 1:
157
+ return "integer", Mbit + 1, None
158
+ if Mbit == 0:
159
+ return "floatExM0", int(match.group(1)), 0
160
+ return "floatExMy", int(match.group(1)), int(match.group(2))
161
+ match = re.match(r"DE(\d+)", string)
162
+ if match:
163
+ return "Dynamic", int(match.group(1)), None
164
+ match = re.match(r"ZeroD(\d+)", string)
165
+ if match:
166
+ return "ZeroDynamic", int(match.group(1)), None
167
+ raise ValueError(f"{string} data format is not supported")
168
+
169
+
170
+ class SymmQuantizer(torch.autograd.function.InplaceFunction):
171
+ @staticmethod
172
+ def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None):
173
+ with torch.no_grad():
174
+ absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon
175
+
176
+ if bits == "100" or not apply_quantize:
177
+ return input, input, torch.ones_like(absmax_per_block)
178
+ elif bits == "FP32":
179
+ return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block)
180
+ elif bits == "FP16":
181
+ return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block)
182
+ elif bits == "BF16":
183
+ return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block)
184
+ else:
185
+ QuantType, bit1, bit2 = extract_bit(bits)
186
+ if not symm:
187
+ bit1 = bit1 + 1 # pretend to be asymmetric
188
+
189
+ if QuantType == "integer":
190
+ Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1
191
+ elif QuantType == "floatExMy":
192
+ Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * (
193
+ 2 ** (2 ** (bit1 - 1))
194
+ )
195
+ if bit1 == 4 and bit2 == 3: # E4M3
196
+ Qn, Qp = -448, 448
197
+ if bit1 == 5 and bit2 == 2: # E5M2
198
+ Qn, Qp = -57344, 57344
199
+ elif QuantType == "floatExM0":
200
+ Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1))
201
+ elif QuantType == "Dynamic":
202
+ Qn, Qp = -1, 1
203
+ elif QuantType == "ZeroDynamic":
204
+ Qn, Qp = -1, 1
205
+ else:
206
+ raise NotImplementedError(f"{bits} is not supported by quantization")
207
+ scale_per_block = (2 * absmax_per_block) / (Qp - Qn)
208
+ scale_per_block = scale_per_block.to(input)
209
+
210
+ Qinput = input / scale_per_block
211
+
212
+ if QuantType == "integer":
213
+ if stochastic:
214
+ noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5)
215
+ Qinput.add_(noise)
216
+ Qinput.clamp_(Qn, Qp).round_()
217
+ elif QuantType == "floatExMy":
218
+ # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic)
219
+ Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic)
220
+ elif QuantType == "floatExM0":
221
+ Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic)
222
+ else:
223
+ raise NotImplementedError(f"{bits} is not supported by quantization")
224
+
225
+ RQinput = Qinput * scale_per_block
226
+
227
+ if input.dtype != Qinput.dtype:
228
+ print(
229
+ f"Input type is {input.dtype}, Qinput type is {Qinput.dtype}, scale_per_block type is {scale_per_block.dtype}",
230
+ file=open("debug.txt", "a"),
231
+ )
232
+ import IPython
233
+
234
+ IPython.embed()
235
+ return RQinput, Qinput, scale_per_block
236
+
237
+ @staticmethod
238
+ def backward(ctx, grad_output):
239
+ return grad_output, None, None, None, None, None
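Two of the helpers above are easy to sanity-check in isolation: `extract_bit` parses the bit-format strings used throughout this module, and `block_cut` / `block_reshape` form an exact round trip when the block sizes divide the tensor shape. A small hedged example:

```python
import torch

from llava.model.coat.activation.fake_quantization.quantize_function import (
    block_cut,
    block_reshape,
    extract_bit,
)

# Bit-format strings used throughout the fake-quantization code.
print(extract_bit("E4M3"))  # ("floatExMy", 4, 3)
print(extract_bit("INT8"))  # ("integer", 8, None)

# block_cut tiles an (M, N) tensor into (num_blocks, row_block, column_block);
# block_reshape undoes it, given the original tensor for shape bookkeeping.
x = torch.arange(24, dtype=torch.float32).reshape(4, 6)
blocks = block_cut(x, row_block=2, column_block=3)              # shape (4, 2, 3)
x_back = block_reshape(blocks, x, row_block=2, column_block=3)  # shape (4, 6)
assert torch.equal(x, x_back)
```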
llava/model/coat/activation/fake_quantization/utils.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import os
24
+
25
+ import matplotlib.pyplot as plt
26
+ import numpy as np
27
+ import torch
28
+
29
+
30
+ def list_has_common_element(list1, list2):
31
+ set1 = set(list1)
32
+ set2 = set(list2)
33
+ return len(set1.intersection(set2)) > 0
34
+
35
+
36
+ def calculate_scale_num(input, row_block, col_block):
37
+ if len(input.shape) > 2:
38
+ input = input.reshape(-1, input.shape[2])
39
+ elif len(input.shape) == 2:
40
+ pass
41
+ else:
42
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
43
+ M, N = input.shape[0], input.shape[1]
44
+
45
+ if row_block == -1:
46
+ row_block = M
47
+ if col_block == -1:
48
+ col_block = N
49
+
50
+ return input.numel() / (row_block * col_block)
51
+
52
+
53
+ def quant_get_local_rank() -> int:
54
+ return int(os.environ.get("LOCAL_RANK") or 0)
55
+
56
+
57
+ def format_string_with_condition(
58
+ input_string,
59
+ condition_config,
60
+ symm,
61
+ bits,
62
+ blocksize_config,
63
+ input_pad=20,
64
+ ):
65
+ padded_string = input_string.ljust(input_pad)
66
+ output_string = padded_string
67
+
68
+ for k, v in condition_config.items():
69
+ if v:
70
+ output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6)
71
+ else:
72
+ output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6)
73
+
74
+ output_string = output_string + f"Symm {symm}".ljust(10)
75
+
76
+ for k, v in bits.items():
77
+ output_string = output_string + f"{k} bit".ljust(10) + v.ljust(10)
78
+ for k, v in blocksize_config.items():
79
+ output_string += f"{k}: {v}".ljust(15)
80
+
81
+ return output_string
82
+
83
+
84
+ def print_warning(sentence):
85
+ print("*" * (len(sentence) + 4))
86
+ print(f"* {sentence} *")
87
+ print("*" * (len(sentence) + 4))
88
+
89
+
90
+ def check_nan_inf(tensor, check_nan, check_inf):
91
+ if check_nan:
92
+ contain_nan = torch.isnan(tensor).any()
93
+ else:
94
+ contain_nan = False
95
+ if check_inf:
96
+ contain_inf = torch.isinf(tensor).any()
97
+ else:
98
+ contain_inf = False
99
+ return contain_nan, contain_inf
100
+
101
+
102
+ def move_torch_to_numpy(tensor):
103
+ if tensor is None:
104
+ return None
105
+
106
+ if tensor.is_cuda:
107
+ tensor = tensor.cpu()
108
+ return tensor.detach().float().numpy()
109
+
110
+
111
+ def flatten_to_1d(tensor):
112
+ if tensor is None:
113
+ return None
114
+
115
+ return tensor.reshape(-1)
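A brief hedged example of the bookkeeping helpers above: `calculate_scale_num` counts how many per-block scales a tensor needs, and `check_nan_inf` returns the two flags:

```python
import torch

from llava.model.coat.activation.fake_quantization.utils import calculate_scale_num, check_nan_inf

x = torch.randn(8, 16)
# One scale per 4x4 block -> (8 * 16) / (4 * 4) = 8 scales.
print(calculate_scale_num(x, row_block=4, col_block=4))   # 8.0
print(check_nan_inf(x, check_nan=True, check_inf=True))   # (tensor(False), tensor(False))
```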
llava/model/coat/activation/models/_fp8_quantization_config.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ from dataclasses import dataclass
8
+
9
+ from transformers import PretrainedConfig
10
+
11
+
12
+ @dataclass
13
+ class QuantizationConfig:
14
+ quantize_model: str = "false"
15
+ symm: bool = True
16
+ epsilon: float = 1e-10
17
+ fabit: str = "E4M3"
18
+ fwbit: str = "E4M3"
19
+ fobit: str = "E4M3"
20
+ babit: str = "E5M2"
21
+ bwbit: str = "E5M2"
22
+ bobit: str = "E5M2"
23
+ qchoice: str = "none"
24
+ group_size: int = -1
25
+ pad_to_multiple_of: int = 0
26
+ weight_memory_efficient: bool = True
27
+
28
+ # Legacy
29
+ row_blocksize: int = -1
30
+ col_blocksize: int = -1
31
+
32
+ def __init__(
33
+ self,
34
+ quantize_model: str = "false",
35
+ symm: bool = True,
36
+ epsilon: float = 1e-10,
37
+ fabit: str = "E4M3",
38
+ fwbit: str = "E4M3",
39
+ fobit: str = "E4M3",
40
+ babit: str = "E5M2",
41
+ bwbit: str = "E5M2",
42
+ bobit: str = "E5M2",
43
+ qchoice: str = "none",
44
+ group_size: int = -1,
45
+ pad_to_multiple_of: int = 0,
46
+ weight_memory_efficient: bool = True,
47
+ row_blocksize: int = -1,
48
+ col_blocksize: int = -1,
49
+ **kwargs,
50
+ ):
51
+ super().__init__()
52
+ self.quantize_model = quantize_model
53
+ self.symm = symm
54
+ self.epsilon = epsilon
55
+ self.fabit = fabit
56
+ self.fwbit = fwbit
57
+ self.fobit = fobit
58
+ self.babit = babit
59
+ self.bwbit = bwbit
60
+ self.bobit = bobit
61
+ self.qchoice = qchoice
62
+ self.group_size = group_size
63
+ self.pad_to_multiple_of = pad_to_multiple_of
64
+ self.weight_memory_efficient = weight_memory_efficient
65
+
66
+ self.row_blocksize = row_blocksize
67
+ self.col_blocksize = col_blocksize
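A hedged example of instantiating this config with the forward-E4M3 / backward-E5M2 split the COAT modules expect; the specific `group_size` and the `quantize_model` string below are illustrative choices, not values mandated by the class:

```python
from llava.model.coat.activation.models._fp8_quantization_config import QuantizationConfig

qargs = QuantizationConfig(
    quantize_model="true",                        # stored as a string flag
    fabit="E4M3", fwbit="E4M3", fobit="E4M3",     # forward activation / weight / output formats
    babit="E5M2", bwbit="E5M2", bobit="E5M2",     # backward gradient formats
    group_size=16,                                # illustrative group size
    weight_memory_efficient=True,
)
print(qargs.fabit, qargs.babit, qargs.group_size)
```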
llava/model/coat/activation/models/_fp8_weightcache.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import torch.nn as nn
8
+
9
+ from ..real_quantization import fp8_division_transpose
10
+
11
+
12
+ class FP8CacheWeightModule(nn.Module):
13
+ def __init__(self, config, qargs, layer_id):
14
+ super().__init__()
15
+ self.config = config
16
+ self.qargs = qargs
17
+ self.layer_id = layer_id
18
+
19
+ def prepare_weight(self, weight, weight_name, is_first_microbatch):
20
+ if is_first_microbatch:
21
+ if self.qargs.weight_memory_efficient:
22
+ # print(f"{weight_name} uses first microbatch")
23
+ weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
24
+ weight, self.qargs.group_size, self.fwobits["fwbit"]
25
+ )
26
+ setattr(self, f"{weight_name}_fp8_scale", weight_s)
27
+ return weight_fp8, weight_fp8_t, weight_s
28
+ else:
29
+ # print(f"{weight_name} uses first microbatch")
30
+ weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
31
+ weight, self.qargs.group_size, self.fwobits["fwbit"]
32
+ )
33
+ setattr(self, f"{weight_name}_fp8", weight_fp8)
34
+ setattr(self, f"{weight_name}_fp8_t", weight_fp8_t)
35
+ setattr(self, f"{weight_name}_fp8_scale", weight_s)
36
+ return weight_fp8, weight_fp8_t, weight_s
37
+ else:
38
+ if self.qargs.weight_memory_efficient:
39
+ return getattr(self, f"{weight_name}_fp8_scale")
40
+ else:
41
+ return (
42
+ getattr(self, f"{weight_name}_fp8"),
43
+ getattr(self, f"{weight_name}_fp8_t"),
44
+ getattr(self, f"{weight_name}_fp8_scale"),
45
+ )
46
+
47
+ def forward(self, x):
48
+ pass
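The caching contract above is: on the first microbatch, quantize the weight with `fp8_division_transpose`, stash the results (or only the scale in the memory-efficient mode) as attributes, and return them; on later microbatches, return the cached attributes instead of re-quantizing. A simplified, self-contained sketch of that control flow, with a toy scale-and-divide stand-in for the real FP8 kernel:

```python
import torch
import torch.nn as nn

class ToyCachedWeight(nn.Module):
    """Illustrative only: cache a derived weight across microbatches."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def prepare_weight(self, is_first_microbatch: bool):
        if is_first_microbatch:
            # Stand-in for fp8_division_transpose: derive the quantized weight and
            # its scale once, then keep them as (non-persistent) buffers.
            scale = self.weight.abs().amax().detach()
            self.register_buffer("weight_q", (self.weight / scale).detach(), persistent=False)
            self.register_buffer("weight_scale", scale, persistent=False)
        return self.weight_q, self.weight_scale

layer = ToyCachedWeight(4)
w_q, s = layer.prepare_weight(is_first_microbatch=True)     # quantize and cache
w_q2, s2 = layer.prepare_weight(is_first_microbatch=False)  # reuse the cached tensors
assert w_q is w_q2 and s is s2
```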
llava/model/coat/activation/models/_fp8manager.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+
26
+ class FP8Manager:
27
+ """Class to keep track of and manipulate the global
28
+ FP8 state at different stages of execution.
29
+ """
30
+
31
+ is_first_microbatch = False
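`is_first_microbatch` is a plain class attribute, so a training loop is expected to flip it once per optimizer step, refreshing the weight cache above only on the first microbatch of each gradient-accumulation window. A hedged sketch of that toggle (the loop structure and variable names are ours):

```python
from llava.model.coat.activation.models._fp8manager import FP8Manager

# Hypothetical gradient-accumulation loop: only the first microbatch of each
# optimizer step re-quantizes the cached FP8 weights.
grad_accum_steps = 4
for step in range(8):  # stand-in for iterating over a dataloader
    FP8Manager.is_first_microbatch = (step % grad_accum_steps == 0)
    # forward/backward would go here; the COAT modules read FP8Manager.is_first_microbatch
```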
llava/model/coat/activation/models/coat_llama.py ADDED
@@ -0,0 +1,1479 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+ import math
26
+ import os
27
+ from fnmatch import fnmatch
28
+ from typing import List, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ from torch import nn
34
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
35
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
36
+ from transformers.activations import ACT2FN
37
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
38
+ from transformers.generation import GenerationMixin
39
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
40
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
41
+ from transformers.modeling_outputs import (
42
+ BaseModelOutputWithPast,
43
+ CausalLMOutputWithPast,
44
+ QuestionAnsweringModelOutput,
45
+ SequenceClassifierOutputWithPast,
46
+ TokenClassifierOutput,
47
+ )
48
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
49
+ from transformers.modeling_utils import PreTrainedModel
50
+ from transformers.models.llama.configuration_llama import LlamaConfig
51
+ from transformers.models.llama.modeling_llama import (
52
+ LlamaAttention,
53
+ LlamaDynamicNTKScalingRotaryEmbedding,
54
+ LlamaForCausalLM,
55
+ LlamaLinearScalingRotaryEmbedding,
56
+ LlamaModel,
57
+ LlamaPreTrainedModel,
58
+ LlamaRMSNorm,
59
+ LlamaRotaryEmbedding,
60
+ _prepare_4d_causal_attention_mask_with_cache_position,
61
+ apply_rotary_pos_emb,
62
+ repeat_kv,
63
+ rotate_half,
64
+ )
65
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
66
+ from transformers.utils import (
67
+ add_start_docstrings,
68
+ add_start_docstrings_to_model_forward,
69
+ is_flash_attn_greater_or_equal_2_10,
70
+ is_torchdynamo_compiling,
71
+ logging,
72
+ replace_return_docstrings,
73
+ )
74
+
75
+ from ..real_quantization import (
76
+ Coat_quantize_bgn,
77
+ Coat_quantize_end,
78
+ fp8_add_Ifp_Ifp_Ofp_Og16,
79
+ fp8_add_Ifp_Ifp_Ofp_Opt,
80
+ fp8_division,
81
+ fp8_division_transpose,
82
+ fp8_gelu_backward,
83
+ fp8_gelu_forward,
84
+ fp8_layernorm_noparam_backward,
85
+ fp8_layernorm_noparam_forward,
86
+ fp8_linear_backward,
87
+ fp8_linear_forward,
88
+ fp8_mul_backward,
89
+ fp8_mul_forward,
90
+ fp8_quantize,
91
+ fp8_quantize_pertensor,
92
+ fp8_quantize_pertensor_transpose,
93
+ fp8_rmsnorm_backward,
94
+ fp8_rmsnorm_forward,
95
+ fp8_silu_backward,
96
+ fp8_silu_forward,
97
+ fp8_transpose,
98
+ )
99
+
100
+ # FP8 related
101
+ from ._fp8_quantization_config import QuantizationConfig
102
+ from ._fp8_weightcache import FP8CacheWeightModule
103
+ from ._fp8manager import FP8Manager
104
+
105
+ logger = logging.get_logger(__name__)
106
+
107
+
108
+ class CoatLlamaConfig(LlamaConfig):
109
+ model_type = "fp8_llama"
110
+
111
+
112
+ class CoatLlamaBeforeAttentionResidual(FP8CacheWeightModule):
113
+ """
114
+ This is a typical transformer pre-attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 3 * Linear layers (the Q/K/V projections)
115
+ """
116
+
117
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_idx: Optional[int] = None):
118
+ super().__init__(config, qargs, layer_idx)
119
+
120
+ self.qargs = qargs
121
+ self.fwobits = {
122
+ "fabit": self.qargs.fabit,
123
+ "fwbit": self.qargs.fwbit,
124
+ "fobit": self.qargs.fobit,
125
+ "babit": self.qargs.babit,
126
+ "bwbit": self.qargs.bwbit,
127
+ "bobit": self.qargs.bobit,
128
+ }
129
+
130
+ self.config = config
131
+ self.layer_idx = layer_idx
132
+ if layer_idx is None:
133
+ logger.warning_once(
134
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
135
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
136
+ "when creating this class."
137
+ )
138
+
139
+ self.attention_dropout = config.attention_dropout
140
+ self.hidden_size = config.hidden_size
141
+ self.num_heads = config.num_attention_heads
142
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
143
+ self.num_key_value_heads = config.num_key_value_heads
144
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
145
+ self.max_position_embeddings = config.max_position_embeddings
146
+ self.rope_theta = config.rope_theta
147
+
148
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
149
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
150
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
151
+
152
+ def forward(self, re_x, x, s, rmsnorm_weight):
153
+ if self.training:
154
+ if self.qargs.weight_memory_efficient:
155
+ # Prepare
156
+ with torch.no_grad():
157
+ weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
158
+ weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
159
+ weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
160
+ return _CoatLlamaBeforeAttentionResidual.apply(
161
+ re_x,
162
+ x,
163
+ s,
164
+ self.q_proj.weight,
165
+ None,
166
+ None,
167
+ weight1_s,
168
+ self.k_proj.weight,
169
+ None,
170
+ None,
171
+ weight2_s,
172
+ self.v_proj.weight,
173
+ None,
174
+ None,
175
+ weight3_s,
176
+ rmsnorm_weight,
177
+ self.qargs.group_size,
178
+ self.fwobits,
179
+ self.layer_id,
180
+ self.config,
181
+ self.qargs,
182
+ )
183
+ else:
184
+ # Prepare
185
+ with torch.no_grad():
186
+ weight1, weight1_t, weight1_s = self.prepare_weight(
187
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
188
+ )
189
+ weight2, weight2_t, weight2_s = self.prepare_weight(
190
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
191
+ )
192
+ weight3, weight3_t, weight3_s = self.prepare_weight(
193
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
194
+ )
195
+ return _CoatLlamaBeforeAttentionResidual.apply(
196
+ re_x,
197
+ x,
198
+ s,
199
+ self.q_proj.weight,
200
+ weight1,
201
+ weight1_t,
202
+ weight1_s,
203
+ self.k_proj.weight,
204
+ weight2,
205
+ weight2_t,
206
+ weight2_s,
207
+ self.v_proj.weight,
208
+ weight3,
209
+ weight3_t,
210
+ weight3_s,
211
+ rmsnorm_weight,
212
+ self.qargs.group_size,
213
+ self.fwobits,
214
+ self.layer_id,
215
+ self.config,
216
+ self.qargs,
217
+ )
218
+ else:
219
+ return re_x, self.att_proj(self.attn_norm(re_x))
220
+
221
+
222
+ class _CoatLlamaBeforeAttentionResidual(torch.autograd.Function):
223
+ @staticmethod
224
+ def forward(
225
+ ctx,
226
+ re_x,
227
+ in_x,
228
+ in_s,
229
+ weight1_origin,
230
+ weight1,
231
+ weight1_t,
232
+ weight1_s,
233
+ weight2_origin,
234
+ weight2,
235
+ weight2_t,
236
+ weight2_s,
237
+ weight3_origin,
238
+ weight3,
239
+ weight3_t,
240
+ weight3_s,
241
+ rmsnorm_weight,
242
+ group_size,
243
+ fwobits,
244
+ layer_id,
245
+ config,
246
+ qargs,
247
+ eps=1e-5,
248
+ ):
249
+ # for autograd
250
+ if fwobits["fabit"] == "E4M3":
251
+ # in_x = in_x.to(torch.float8_e4m3fn)
252
+ in_x = in_x.view(torch.float8_e4m3fn)
253
+ else:
254
+ raise ValueError("fabit should be E4M3")
255
+
256
+ # LayerNorm
257
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
258
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
259
+ )
260
+
261
+ # Linear Layer QKV Projection
262
+ if qargs.weight_memory_efficient:
263
+ assert weight1 is None # memory efficient
264
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
265
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
266
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
267
+
268
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size) # query states
269
+ fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size) # key states
270
+ fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size) # value states
271
+
272
+ # ==================== save for backward ====================
273
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
274
+ if qargs.weight_memory_efficient:
275
+ assert weight1_t is None and weight2_t is None and weight3_t is None
276
+ ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
277
+ else:
278
+ ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
279
+
280
+ ctx.group_size = group_size
281
+ ctx.ln_utils = ln_utils
282
+ ctx.utils = fwobits, layer_id, config, qargs
283
+
284
+ return re_x, fc1_x, fc2_x, fc3_x
285
+
286
+ @staticmethod
287
+ def backward(ctx, fp_grad, query_g, key_g, value_g):
288
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
289
+ weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
290
+
291
+ group_size = ctx.group_size
292
+ rms_weight, rstd, num_warps = ctx.ln_utils
293
+ fwobits, layer_id, config, qargs = ctx.utils
294
+
295
+ # ==================== Begin backward ====================
296
+ # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout.
297
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
298
+ query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
299
+ )
300
+ key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
301
+ key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
302
+ )
303
+ value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
304
+ value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
305
+ )
306
+
307
+ # Linear Layer QKV Projection
308
+ if qargs.weight_memory_efficient:
309
+ weight1_t, weight1_s = fp8_division_transpose(
310
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
311
+ )
312
+ weight2_t, weight2_s = fp8_division_transpose(
313
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
314
+ )
315
+ weight3_t, weight3_s = fp8_division_transpose(
316
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
317
+ )
318
+
319
+ fc1_g1, att_q_wg = fp8_linear_backward(
320
+ ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
321
+ )
322
+ fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
323
+ fc1_g3, att_v_wg = fp8_linear_backward(
324
+ ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
325
+ )
326
+
327
+ fc1_g = fc1_g1 + fc1_g2 + fc1_g3
328
+
329
+ # LayerNorm
330
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)
331
+
332
+ # Add the gradient together, and prepare the input of the next layer.
333
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
334
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
335
+ )
336
+
337
+ # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation.
338
+ in_g = in_g.view(torch.float8_e4m3fn)
339
+
340
+ # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd.
341
+ return (
342
+ re_g,
343
+ in_g,
344
+ in_sg_g16,
345
+ att_q_wg,
346
+ None,
347
+ None,
348
+ None,
349
+ att_k_wg,
350
+ None,
351
+ None,
352
+ None,
353
+ att_v_wg,
354
+ None,
355
+ None,
356
+ None,
357
+ rms_weight_grad,
358
+ None,
359
+ None,
360
+ None,
361
+ None,
362
+ None,
363
+ None,
364
+ )
365
+
366
+
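The wrapper/`autograd.Function` pair above follows the pattern used throughout this file: the module prepares (and possibly caches) FP8 weights, while the Function runs the fused forward (residual pass-through, RMSNorm, Q/K/V projections), saves only the quantized activations it needs, and returns one gradient (or `None` for non-tensor arguments) per forward input. A much-simplified full-precision sketch of that skeleton, with all FP8 kernels elided:

```python
import torch

class _ToyBeforeAttention(torch.autograd.Function):
    """Skeleton only: mirrors the save-for-backward / per-input-gradient pattern."""

    @staticmethod
    def forward(ctx, re_x, x, w_q, w_k, w_v):
        ctx.save_for_backward(x, w_q, w_k, w_v)          # save activations + weights
        return re_x, x @ w_q.t(), x @ w_k.t(), x @ w_v.t()

    @staticmethod
    def backward(ctx, re_g, q_g, k_g, v_g):
        x, w_q, w_k, w_v = ctx.saved_tensors
        x_g = q_g @ w_q + k_g @ w_k + v_g @ w_v          # sum over the three branches
        # One gradient per forward input: (re_x, x, w_q, w_k, w_v)
        return re_g, x_g, q_g.t() @ x, k_g.t() @ x, v_g.t() @ x

re_x = torch.randn(2, 8)
x = torch.randn(2, 8, requires_grad=True)
w_q, w_k, w_v = (torch.randn(8, 8, requires_grad=True) for _ in range(3))
_, q, k, v = _ToyBeforeAttention.apply(re_x, x, w_q, w_k, w_v)
(q.sum() + k.sum() + v.sum()).backward()                 # populates x.grad, w_*.grad
```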
367
+ class CoatLlamaAfterAttentionResidual(FP8CacheWeightModule):
368
+ """
369
+ This is a typical transformer post-attention module that contains (1) Residual (2) 1 * Linear layer (the attention output projection)
370
+ """
371
+
372
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id):
373
+ super().__init__(config, qargs, layer_id)
374
+
375
+ self.qargs = qargs
376
+ self.fwobits = {
377
+ "fabit": self.qargs.fabit,
378
+ "fwbit": self.qargs.fwbit,
379
+ "fobit": self.qargs.fobit,
380
+ "babit": self.qargs.babit,
381
+ "bwbit": self.qargs.bwbit,
382
+ "bobit": self.qargs.bobit,
383
+ }
384
+
385
+ self.hidden_size = config.hidden_size
386
+ self.num_heads = config.num_attention_heads
387
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
388
+
389
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
390
+
391
+ def forward(self, re_x, in_x):
392
+ if self.training:
393
+ if self.qargs.weight_memory_efficient:
394
+ # prepare for the weight
395
+ with torch.no_grad():
396
+ weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
397
+
398
+ return _CoatLlamaAfterAttentionResidual.apply(
399
+ re_x,
400
+ in_x,
401
+ self.o_proj.weight,
402
+ None,
403
+ None,
404
+ weight4_s,
405
+ self.qargs.group_size,
406
+ self.fwobits,
407
+ self.layer_id,
408
+ self.config,
409
+ self.qargs,
410
+ )
411
+ else:
412
+ # prepare for the weight
413
+ with torch.no_grad():
414
+ weight4, weight4_t, weight4_s = self.prepare_weight(
415
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
416
+ )
417
+
418
+ return _CoatLlamaAfterAttentionResidual.apply(
419
+ re_x,
420
+ in_x,
421
+ self.o_proj.weight,
422
+ weight4,
423
+ weight4_t,
424
+ weight4_s,
425
+ self.qargs.group_size,
426
+ self.fwobits,
427
+ self.layer_id,
428
+ self.config,
429
+ self.qargs,
430
+ )
431
+ else:
432
+ return re_x + self.attn_out(in_x), None, None
433
+
434
+
435
+ class _CoatLlamaAfterAttentionResidual(torch.autograd.Function):
436
+ @staticmethod
437
+ def forward(
438
+ ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs
439
+ ):
440
+ # Quantize the FlashAttention Output
441
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
442
+ flash_x, group_size, fwobits["fabit"]
443
+ ) # Modified to make it memory efficient
444
+
445
+ # Attention Projection Linear Layer
446
+ if qargs.weight_memory_efficient:
447
+ assert weight4 is None # memory efficient
448
+ weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)
449
+ fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size)
450
+
453
+ # Add the activations together
454
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size)
455
+
456
+ # ==================== save for backward ====================
457
+ ctx.save_for_backward(flash_x, flash_s)
458
+ if qargs.weight_memory_efficient:
459
+ assert weight4_t is None
460
+ ctx.weight = weight4_origin, weight4_s
461
+ else:
462
+ ctx.weight = weight4_t, weight4_s
463
+ ctx.group_size = group_size
464
+ ctx.fwobits = fwobits
465
+ ctx.utils = fwobits, layer_id, config, qargs
466
+
467
+ # For autograd
468
+ out_x = out_x.view(torch.float8_e4m3fn)
469
+
470
+ return fp_x, out_x, out_s
471
+
472
+ @staticmethod
473
+ def backward(ctx, fp_grad, out_g, out_gs):
474
+ flash_x, flash_s = ctx.saved_tensors
475
+ weight4_t, weight4_s = ctx.weight
476
+ group_size = ctx.group_size
477
+ fwobits = ctx.fwobits
478
+ fwobits, layer_id, config, qargs = ctx.utils
479
+
480
+ # for autograd
481
+ if fwobits["babit"] == "E5M2":
482
+ # out_g = out_g.to(torch.float8_e5m2)
483
+ out_g = out_g.view(torch.float8_e5m2)
484
+ else:
485
+ raise ValueError("babit should be E5M2")
486
+ out_gs_max = out_gs.max()
487
+
488
+ # ==================== Begin backward ====================
489
+ # Output Projection
490
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
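+ # grad_input and grad_weight need different data layouts, so both out_g and its transpose out_g_t
+ # are passed to fp8_linear_backward below.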
491
+
492
+ # We do not save an extra copy of flash_x, to reduce memory usage
493
+ flash_x_t, flash_s = fp8_division_transpose(
494
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
495
+ )
496
+
497
+ if qargs.weight_memory_efficient:
498
+ weight4_t, weight4_s = fp8_division_transpose(
499
+ weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
500
+ )
501
+ fc4_g, attn_out_wg = fp8_linear_backward(
502
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
503
+ )
504
+
505
+ return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
506
+
507
+
508
+ class CoatLlamaMLPResidual(FP8CacheWeightModule):
509
+ """
510
+ This is a typical transformer MLP module that contains (1) a Residual connection (2) LayerNorm / RMSNorm (3) 2 / 3 Linear layers
511
+ (4) GELU / Silu Activation
512
+ """
513
+
514
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id, hidden_size: int):
515
+ super().__init__(config, qargs, layer_id)
516
+
517
+ self.qargs = qargs
518
+ self.fwobits = {
519
+ "fabit": self.qargs.fabit,
520
+ "fwbit": self.qargs.fwbit,
521
+ "fobit": self.qargs.fobit,
522
+ "babit": self.qargs.babit,
523
+ "bwbit": self.qargs.bwbit,
524
+ "bobit": self.qargs.bobit,
525
+ }
526
+
527
+ self.config = config
528
+ self.hidden_size = config.hidden_size
529
+ self.intermediate_size = config.intermediate_size
530
+
531
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
532
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
533
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
534
+ self.training = True
535
+
536
+ # below is only used when training = False
537
+ assert config.hidden_act == "silu", "We only support silu activation currently"
538
+ self.act_fn = ACT2FN[config.hidden_act]
539
+
540
+ def forward(self, re_x, x, s, rmsnorm_weight):
541
+ if self.training:
542
+ if self.qargs.weight_memory_efficient: # prepare for the weight
543
+ with torch.no_grad():
544
+ weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch)
545
+ weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
546
+ weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch)
547
+
548
+ return _CoatLlamaMLPResidual.apply(
549
+ re_x,
550
+ x,
551
+ s,
552
+ self.gate_proj.weight,
553
+ None,
554
+ None,
555
+ weight1_s,
556
+ self.up_proj.weight,
557
+ None,
558
+ None,
559
+ weight2_s,
560
+ self.down_proj.weight,
561
+ None,
562
+ None,
563
+ weight3_s,
564
+ rmsnorm_weight,
565
+ self.qargs.group_size,
566
+ self.fwobits,
567
+ self.layer_id,
568
+ self.config,
569
+ self.qargs,
570
+ )
571
+ else:
572
+ # prepare for the weight
573
+ with torch.no_grad():
574
+ weight1, weight1_t, weight1_s = self.prepare_weight(
575
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
576
+ )
577
+ weight2, weight2_t, weight2_s = self.prepare_weight(
578
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
579
+ )
580
+ weight3, weight3_t, weight3_s = self.prepare_weight(
581
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
582
+ )
583
+
584
+ return _CoatLlamaMLPResidual.apply(
585
+ re_x,
586
+ x,
587
+ s,
588
+ self.gate_proj.weight,
589
+ weight1,
590
+ weight1_t,
591
+ weight1_s,
592
+ self.up_proj.weight,
593
+ weight2,
594
+ weight2_t,
595
+ weight2_s,
596
+ self.down_proj.weight,
597
+ weight3,
598
+ weight3_t,
599
+ weight3_s,
600
+ rmsnorm_weight,
601
+ self.qargs.group_size,
602
+ self.fwobits,
603
+ self.layer_id,
604
+ self.config,
605
+ self.qargs,
606
+ )
607
+ else:
608
+ raise NotImplementedError("Inference (eval) path for CoatLlamaMLPResidual is not implemented yet")
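+ # NOTE: the lines below are unreachable reference code carried over from the OLMo variant
+ # (ff_norm / ff_proj / act / ff_out are not defined on this module).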
609
+ og_x = re_x
610
+ re_x = self.ff_norm(re_x)
611
+ re_x = self.ff_proj(re_x)
612
+ re_x = self.act(re_x)
613
+ re_x = self.ff_out(re_x)
614
+ re_x = og_x + re_x
615
+ return re_x, None, None
616
+
617
+
618
+ class _CoatLlamaMLPResidual(torch.autograd.Function):
619
+ @staticmethod
620
+ def forward(
621
+ ctx,
622
+ re_x,
623
+ in_x,
624
+ in_s,
625
+ weight1_origin,
626
+ weight1,
627
+ weight1_t,
628
+ weight1_s,
629
+ weight2_origin,
630
+ weight2,
631
+ weight2_t,
632
+ weight2_s,
633
+ weight3_origin,
634
+ weight3,
635
+ weight3_t,
636
+ weight3_s,
637
+ rmsnorm_weight,
638
+ group_size,
639
+ fwobits,
640
+ layer_id,
641
+ config,
642
+ qargs,
643
+ eps=1e-5,
644
+ ):
645
+ # For autograd
646
+ if fwobits["fabit"] == "E4M3":
647
+ # in_x = in_x.to(torch.float8_e4m3fn)
648
+ in_x = in_x.view(torch.float8_e4m3fn)
649
+ else:
650
+ raise ValueError("fabit should be E4M3")
651
+
652
+ # LayerNorm
653
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
654
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
655
+ )
656
+
657
+ # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer.
658
+ if qargs.weight_memory_efficient:
659
+ assert weight1 is None and weight2 is None and weight3 is None # memory efficient
660
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
661
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
662
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
663
+
664
+ gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj
665
+ up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj
666
+
667
+ # silu Activation
668
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
669
+
670
+ # Element-wise Multiplication
671
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
672
+
673
+ # Output Projection
674
+ if weight3 is None: # memory efficient
675
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
676
+ fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)
677
+
678
+ # Add the activation together
679
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size)
680
+
681
+ # ==================== save for backward ====================
682
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
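+ # The saved tensors are the quantized FP8 activations together with their per-group scales,
+ # rather than BF16 copies, which is where the activation-memory saving of this module comes from.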
683
+
685
+ if (
686
+ qargs.weight_memory_efficient
687
+ ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient.
688
+ assert weight1_t is None and weight2_t is None and weight3_t is None
689
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
690
+ else:  # Weight1/2_t is different from the original weight, so saving it consumes additional memory.
691
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
692
+
693
+ ctx.group_size = group_size
694
+ ctx.ln_utils = ln_utils
695
+ ctx.utils = fwobits, layer_id, config, qargs
696
+
697
+ out_x = out_x.view(torch.float8_e4m3fn)
698
+
699
+ return fp_x, out_x, out_s
700
+
701
+ @staticmethod
702
+ def backward(ctx, fp_grad, out_g, out_gs):
703
+ fwobits, layer_id, config, qargs = ctx.utils
704
+
705
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
706
+
707
+ (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
708
+ group_size = ctx.group_size
709
+ rms_weight, rstd, num_warps = ctx.ln_utils
710
+ fwobits, layer_id, config, qargs = ctx.utils
711
+
712
+ # For autograd
713
+ if fwobits["babit"] == "E5M2":
714
+ # out_g = out_g.to(torch.float8_e5m2)
715
+ out_g = out_g.view(torch.float8_e5m2)
716
+ else:
717
+ raise ValueError("babit should be E5M2")
718
+ out_gs_max = out_gs.max()
719
+
720
+ # ==================== Begin backward ====================
721
+ # Output Projection
723
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
724
+
725
+ if qargs.weight_memory_efficient:
726
+ weight3_t, weight3_s = fp8_division_transpose(
727
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
728
+ )
729
+ fc3_g, weight3_grad = fp8_linear_backward(
730
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
731
+ )
732
+
733
+ # [MEM TEST]
734
+ del out_g, out_g_t, weight3_t
735
+
736
+ # Element-wise Multiplication, 1 means gate, 2 means up
737
+ mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
738
+ silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
739
+ )
740
+
741
+ # Silu activation
742
+ silu_g, silu_gs, silu_g_t = fp8_silu_backward(
743
+ gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
744
+ )
745
+
746
+ # Linear Layer of Up and Gate Projection
747
+ if qargs.weight_memory_efficient:
748
+ weight1_t, weight1_s = fp8_division_transpose(
749
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
750
+ )
751
+ weight2_t, weight2_s = fp8_division_transpose(
752
+ weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
753
+ )
754
+
755
+ # Gate Proj
756
+ fc1_g, weight1_grad = fp8_linear_backward(
757
+ ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
758
+ )
759
+ fc2_g, weight2_grad = fp8_linear_backward(
760
+ ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
761
+ )
762
+
763
+ fc_g = fc1_g + fc2_g
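+ # The gate and up projections both consume the same RMSNorm output, so their input gradients
+ # are summed before being propagated through the RMSNorm backward.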
764
+
765
+ # layerNorm
766
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)
767
+
768
+ # Add the gradient together
769
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
770
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
771
+ )
772
+
773
+ in_g = in_g.view(torch.float8_e4m3fn)
774
+
775
+ return (
776
+ re_g,
777
+ in_g,
778
+ in_sg_g16,
779
+ weight1_grad,
780
+ None,
781
+ None,
782
+ None,
783
+ weight2_grad,
784
+ None,
785
+ None,
786
+ None,
787
+ weight3_grad,
788
+ None,
789
+ None,
790
+ None,
791
+ rms_weight_grad,
792
+ None,
793
+ None,
794
+ None,
795
+ None,
796
+ None,
797
+ None,
798
+ )
799
+
800
+
801
+ class LlamaAttentionWithoutLinear(nn.Module):
802
+ """
803
+ Remove the Q/K/V/O projection layer in LlamaAttention module and only calculate the attention logic.
804
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
805
+ """
806
+
807
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
808
+ super().__init__()
809
+ self.config = config
810
+ self.layer_idx = layer_idx
811
+ if layer_idx is None:
812
+ logger.warning_once(
813
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
814
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
815
+ "when creating this class."
816
+ )
817
+
818
+ self.attention_dropout = config.attention_dropout
819
+ self.hidden_size = config.hidden_size
820
+ self.num_heads = config.num_attention_heads
821
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
822
+ self.num_key_value_heads = config.num_key_value_heads
823
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
824
+ self.max_position_embeddings = config.max_position_embeddings
825
+ self.rope_theta = config.rope_theta
826
+ self.is_causal = True
827
+
828
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
829
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
830
+
831
+ def forward(
832
+ self,
833
+ query_states: torch.Tensor,
834
+ key_states: torch.Tensor,
835
+ value_states: torch.Tensor,
836
+ attention_mask: Optional[torch.Tensor] = None,
837
+ position_ids: Optional[torch.LongTensor] = None,
838
+ past_key_value: Optional[Cache] = None,
839
+ output_attentions: bool = False,
840
+ use_cache: bool = False,
841
+ cache_position: Optional[torch.LongTensor] = None,
842
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
843
+ **kwargs,
844
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
845
+ bsz, q_len, _ = query_states.size()
846
+
847
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
848
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
849
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
850
+
851
+ if position_embeddings is None:
852
+ logger.warning_once(
853
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
854
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
855
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
856
+ "removed and `position_embeddings` will be mandatory."
857
+ )
858
+ cos, sin = self.rotary_emb(value_states, position_ids)
859
+ else:
860
+ cos, sin = position_embeddings
861
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
862
+
863
+ if past_key_value is not None:
864
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
865
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
866
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
867
+
868
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
869
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
870
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
871
+
872
+ if attention_mask is not None: # no matter the length, we just slice it
873
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
874
+ attn_weights = attn_weights + causal_mask
875
+
876
+ # upcast attention to fp32
877
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
878
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
879
+ attn_output = torch.matmul(attn_weights, value_states)
880
+
881
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
882
+ raise ValueError(
883
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
884
+ f" {attn_output.size()}"
885
+ )
886
+
887
+ attn_output = attn_output.transpose(1, 2).contiguous()
888
+
889
+ attn_output = attn_output.reshape(bsz, q_len, -1)
890
+
891
+ if not output_attentions:
892
+ attn_weights = None
893
+
894
+ return attn_output, attn_weights, past_key_value
895
+
896
+
897
+ class LlamaFlashAttention2WithoutLinear(LlamaAttentionWithoutLinear):
898
+ """
899
+ Remove the Q/K/V/O projection layer in LlamaFlashAttention2 module and only calculate the attention logic.
900
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
901
+ """
902
+
903
+ def __init__(self, *args, **kwargs):
904
+ super().__init__(*args, **kwargs)
905
+
906
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
907
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
908
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
909
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
910
+
911
+ def forward(
912
+ self,
913
+ query_states: torch.Tensor,
914
+ key_states: torch.Tensor,
915
+ value_states: torch.Tensor,
916
+ attention_mask: Optional[torch.LongTensor] = None,
917
+ position_ids: Optional[torch.LongTensor] = None,
918
+ past_key_value: Optional[Cache] = None,
919
+ output_attentions: bool = False,
920
+ use_cache: bool = False,
921
+ cache_position: Optional[torch.LongTensor] = None,
922
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
923
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
924
+ if isinstance(past_key_value, StaticCache):
925
+ raise ValueError(
926
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
927
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
928
+ )
929
+
930
+ output_attentions = False
931
+
932
+ bsz, q_len, _ = query_states.size()
933
+
934
+ # Flash attention requires the input to have the shape
935
+ # batch_size x seq_length x head_dim x hidden_dim
936
+ # therefore we just need to keep the original shape
937
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
938
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
939
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
940
+
941
+ if position_embeddings is None:
942
+ logger.warning_once(
943
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
944
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
945
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
946
+ "removed and `position_embeddings` will be mandatory."
947
+ )
948
+ cos, sin = self.rotary_emb(value_states, position_ids)
949
+ else:
950
+ cos, sin = position_embeddings
951
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
952
+
953
+ if past_key_value is not None:
954
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
955
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
956
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
957
+
958
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
959
+ # to be able to avoid many of these transpose/reshape/view.
960
+ query_states = query_states.transpose(1, 2)
961
+ key_states = key_states.transpose(1, 2)
962
+ value_states = value_states.transpose(1, 2)
963
+
964
+ dropout_rate = self.attention_dropout if self.training else 0.0
965
+
966
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
967
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
968
+ # cast them back to the correct dtype just to be sure everything works as expected.
969
+ # This might slow down training & inference so it is recommended not to cast the LayerNorms
970
+ # in fp32. (LlamaRMSNorm handles it correctly)
971
+
972
+ input_dtype = query_states.dtype
973
+ if input_dtype == torch.float32:
974
+ if torch.is_autocast_enabled():
975
+ target_dtype = torch.get_autocast_gpu_dtype()
976
+ # Handle the case where the model is quantized
977
+ elif hasattr(self.config, "_pre_quantization_dtype"):
978
+ target_dtype = self.config._pre_quantization_dtype
979
+ else:
980
+ target_dtype = self.q_proj.weight.dtype
981
+
982
+ logger.warning_once(
983
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
984
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
985
+ f" {target_dtype}."
986
+ )
987
+
988
+ query_states = query_states.to(target_dtype)
989
+ key_states = key_states.to(target_dtype)
990
+ value_states = value_states.to(target_dtype)
991
+
992
+ attn_output = _flash_attention_forward(
993
+ query_states,
994
+ key_states,
995
+ value_states,
996
+ attention_mask,
997
+ q_len,
998
+ position_ids=position_ids,
999
+ dropout=dropout_rate,
1000
+ sliding_window=getattr(self, "sliding_window", None),
1001
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
1002
+ is_causal=self.is_causal,
1003
+ )
1004
+
1005
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
1006
+
1007
+ if not output_attentions:
1008
+ attn_weights = None
1009
+
1010
+ return attn_output, attn_weights, past_key_value
1011
+
1012
+
1013
+ class LlamaSdpaAttentionWithoutLinear(LlamaAttentionWithoutLinear):
1014
+ """
1015
+ Remove the Q/K/V/O projection layer in LlamaSdpaAttention module and only calculate the attention logic.
1016
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
1017
+ """
1018
+
1019
+ # Adapted from LlamaAttention.forward
1020
+ def forward(
1021
+ self,
1022
+ query_states: torch.Tensor,
1023
+ key_states: torch.Tensor,
1024
+ value_states: torch.Tensor,
1025
+ attention_mask: Optional[torch.Tensor] = None,
1026
+ position_ids: Optional[torch.LongTensor] = None,
1027
+ past_key_value: Optional[Cache] = None,
1028
+ output_attentions: bool = False,
1029
+ use_cache: bool = False,
1030
+ cache_position: Optional[torch.LongTensor] = None,
1031
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
1032
+ **kwargs,
1033
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1034
+ if output_attentions:
1035
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
1036
+ logger.warning_once(
1037
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
1038
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
1039
+ )
1040
+ return super().forward(
1041
+ query_states=query_states,
1042
+ key_states=key_states,
1043
+ value_states=value_states,
1044
+ attention_mask=attention_mask,
1045
+ position_ids=position_ids,
1046
+ past_key_value=past_key_value,
1047
+ output_attentions=output_attentions,
1048
+ use_cache=use_cache,
1049
+ cache_position=cache_position,
1050
+ position_embeddings=position_embeddings,
1051
+ )
1052
+
1053
+ bsz, q_len, _ = query_states.size()
1054
+
1055
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1056
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1057
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1058
+
1059
+ if position_embeddings is None:
1060
+ logger.warning_once(
1061
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
1062
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
1063
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
1064
+ "removed and `position_embeddings` will be mandatory."
1065
+ )
1066
+ cos, sin = self.rotary_emb(value_states, position_ids)
1067
+ else:
1068
+ cos, sin = position_embeddings
1069
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
1070
+
1071
+ if past_key_value is not None:
1072
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
1073
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
1074
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1075
+
1076
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1077
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1078
+
1079
+ causal_mask = attention_mask
1080
+ if attention_mask is not None:
1081
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
1082
+
1083
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1084
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1085
+ if query_states.device.type == "cuda" and causal_mask is not None:
1086
+ query_states = query_states.contiguous()
1087
+ key_states = key_states.contiguous()
1088
+ value_states = value_states.contiguous()
1089
+
1090
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
1091
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
1092
+ is_causal = True if causal_mask is None and q_len > 1 else False
1093
+
1094
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1095
+ query_states,
1096
+ key_states,
1097
+ value_states,
1098
+ attn_mask=causal_mask,
1099
+ dropout_p=self.attention_dropout if self.training else 0.0,
1100
+ is_causal=is_causal,
1101
+ )
1102
+
1103
+ attn_output = attn_output.transpose(1, 2).contiguous()
1104
+ attn_output = attn_output.view(bsz, q_len, -1)
1105
+
1106
+ return attn_output, None, past_key_value
1107
+
1108
+
1109
+ COAT_LLAMA_ATTENTION_CLASSES = {
1110
+ "eager": LlamaAttentionWithoutLinear,
1111
+ "flash_attention_2": LlamaFlashAttention2WithoutLinear,
1112
+ "sdpa": LlamaSdpaAttentionWithoutLinear,
1113
+ }
1114
+
1115
+
1116
+ class CoatLlamaDecoderLayer(nn.Module):
1117
+ def __init__(self, config: CoatLlamaConfig, layer_idx: int):
1118
+ super().__init__()
1119
+ self.layer_idx = layer_idx
1120
+ self.hidden_size = config.hidden_size
1121
+
1122
+ self.self_attn = COAT_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
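+ # The attention classes above contain no Q/K/V/O linear layers; those projections live in the
+ # BeforeAttention / AfterAttention modules created below so that they can run through the FP8 linear kernels.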
1123
+
1124
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
1125
+ self.BeforeAttention = CoatLlamaBeforeAttentionResidual(config, self.qargs, layer_idx)
1126
+ self.AfterAttention = CoatLlamaAfterAttentionResidual(config, self.qargs, layer_idx)
1127
+ self.MLPResidual = CoatLlamaMLPResidual(config, self.qargs, layer_idx, self.hidden_size)
1128
+
1129
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1130
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1131
+
1132
+ def forward(
1133
+ self,
1134
+ hidden_states: torch.Tensor,
1135
+ quant_hidden_states: torch.Tensor,
1136
+ scale_hidden_states: torch.Tensor,
1137
+ attention_mask: Optional[torch.Tensor] = None,
1138
+ position_ids: Optional[torch.LongTensor] = None,
1139
+ past_key_value: Optional[Cache] = None,
1140
+ output_attentions: Optional[bool] = False,
1141
+ use_cache: Optional[bool] = False,
1142
+ cache_position: Optional[torch.LongTensor] = None,
1143
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
1144
+ **kwargs,
1145
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1146
+ """
1147
+ Args:
1148
+ hidden_states (`torch.FloatTensor`): BF16 input to the layer of shape `(batch, seq_len, embed_dim)`
1149
+ quant_hidden_states (`torch.float8_e4m3fn`): FP8 input to the layer of shape `(batch, seq_len, embed_dim)`
1150
+ scale_hidden_states (`torch.bfloat16`): BF16 scaling factor to the layer of shape `(batch, seq_len, embed_dim // group_size)`
1151
+ attention_mask (`torch.FloatTensor`, *optional*):
1152
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
1153
+ query_sequence_length, key_sequence_length)` if default attention is used.
1154
+ output_attentions (`bool`, *optional*):
1155
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1156
+ returned tensors for more detail.
1157
+ use_cache (`bool`, *optional*):
1158
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1159
+ (see `past_key_values`).
1160
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1161
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1162
+ Indices depicting the position of the input sequence tokens in the sequence
1163
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
1164
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
1165
+ with `head_dim` being the embedding dimension of each attention head.
1166
+ kwargs (`dict`, *optional*):
1167
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
1168
+ into the model
1169
+ """
1170
+
1171
+ # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
1172
+ residual, query_states, key_states, value_states = self.BeforeAttention(
1173
+ hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
1174
+ )
1175
+
1176
+ # Self Attention without any linear layer
1177
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1178
+ query_states=query_states,
1179
+ key_states=key_states,
1180
+ value_states=value_states,
1181
+ attention_mask=attention_mask,
1182
+ position_ids=position_ids,
1183
+ past_key_value=past_key_value,
1184
+ output_attentions=output_attentions,
1185
+ use_cache=use_cache,
1186
+ cache_position=cache_position,
1187
+ position_embeddings=position_embeddings,
1188
+ **kwargs,
1189
+ )
1190
+
1191
+ # Coat: The Output Projection Linear Layer and Residual
1192
+ hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states)
1193
+
1194
+ # Residual Connection, LayerNorm, and the whole MLP module
1195
+ hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual(
1196
+ hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
1197
+ )
1198
+
1199
+ outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),)
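+ # Unlike the stock LlamaDecoderLayer, the first element of `outputs` is a (bf16, fp8, scale) triplet,
+ # which the next decoder layer (or quantize_output_after_block) unpacks.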
1200
+
1201
+ if output_attentions:
1202
+ outputs += (self_attn_weights,)
1203
+
1204
+ if use_cache:
1205
+ outputs += (present_key_value,)
1206
+
1207
+ return outputs
1208
+
1209
+
1210
+ class CoatLlamaPreTrainedModel(PreTrainedModel):
1211
+ config_class = CoatLlamaConfig
1212
+ base_model_prefix = "model"
1213
+ supports_gradient_checkpointing = True
1214
+ _no_split_modules = ["LlamaDecoderLayer"]
1215
+ _skip_keys_device_placement = ["past_key_values"]
1216
+ _supports_flash_attn_2 = True
1217
+ _supports_sdpa = True
1218
+ _supports_cache_class = True
1219
+ _supports_quantized_cache = True
1220
+ _supports_static_cache = True
1221
+
1222
+ def _init_weights(self, module):
1223
+ std = self.config.initializer_range
1224
+ if isinstance(module, nn.Linear):
1225
+ module.weight.data.normal_(mean=0.0, std=std)
1226
+ if module.bias is not None:
1227
+ module.bias.data.zero_()
1228
+ elif isinstance(module, nn.Embedding):
1229
+ module.weight.data.normal_(mean=0.0, std=std)
1230
+ if module.padding_idx is not None:
1231
+ module.weight.data[module.padding_idx].zero_()
1232
+
1233
+
1234
+ class CoatLlamaModel(CoatLlamaPreTrainedModel):
1235
+ """
1236
+ Coat Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CoatLlamaDecoderLayer`]
1237
+
1238
+ Args:
1239
+ config: CoatLlamaConfig
1240
+ """
1241
+
1242
+ def __init__(self, config: CoatLlamaConfig):
1243
+ super().__init__(config)
1244
+ self.padding_idx = config.pad_token_id
1245
+ self.vocab_size = config.vocab_size
1246
+
1247
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1248
+ self.layers = nn.ModuleList(
1249
+ [CoatLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1250
+ )
1251
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1252
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
1253
+ self.gradient_checkpointing = False
1254
+
1255
+ # Quantize
1256
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
1257
+ self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
1258
+ self.quantize_output_after_block = Coat_quantize_end(self.qargs)
1259
+
1260
+ # Initialize weights and apply final processing
1261
+ self.post_init()
1262
+
1263
+ def get_input_embeddings(self):
1264
+ return self.embed_tokens
1265
+
1266
+ def set_input_embeddings(self, value):
1267
+ self.embed_tokens = value
1268
+
1269
+ def forward(
1270
+ self,
1271
+ input_ids: torch.LongTensor = None,
1272
+ attention_mask: Optional[torch.Tensor] = None,
1273
+ position_ids: Optional[torch.LongTensor] = None,
1274
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1275
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1276
+ use_cache: Optional[bool] = None,
1277
+ output_attentions: Optional[bool] = None,
1278
+ output_hidden_states: Optional[bool] = None,
1279
+ return_dict: Optional[bool] = None,
1280
+ cache_position: Optional[torch.LongTensor] = None,
1281
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1282
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1283
+ output_hidden_states = (
1284
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1285
+ )
1286
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1287
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1288
+
1289
+ if (input_ids is None) ^ (inputs_embeds is not None):
1290
+ raise ValueError(
1291
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1292
+ )
1293
+
1294
+ if self.gradient_checkpointing and self.training and use_cache:
1295
+ logger.warning_once(
1296
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1297
+ )
1298
+ use_cache = False
1299
+
1300
+ if inputs_embeds is None:
1301
+ inputs_embeds = self.embed_tokens(input_ids)
1302
+
1303
+ # kept for BC (non `Cache` `past_key_values` inputs)
1304
+ return_legacy_cache = False
1305
+ if use_cache and not isinstance(past_key_values, Cache):
1306
+ return_legacy_cache = True
1307
+ if past_key_values is None:
1308
+ past_key_values = DynamicCache()
1309
+ else:
1310
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1311
+ logger.warning_once(
1312
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
1313
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
1314
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
1315
+ )
1316
+
1317
+ if cache_position is None:
1318
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1319
+ cache_position = torch.arange(
1320
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1321
+ )
1322
+ if position_ids is None:
1323
+ position_ids = cache_position.unsqueeze(0)
1324
+
1325
+ causal_mask = self._update_causal_mask(
1326
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1327
+ )
1328
+ hidden_states = inputs_embeds
1329
+
1330
+ # create position embeddings to be shared across the decoder layers
1331
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1332
+
1333
+ # decoder layers
1334
+ all_hidden_states = () if output_hidden_states else None
1335
+ all_self_attns = () if output_attentions else None
1336
+ next_decoder_cache = None
1337
+
1338
+ # Prepare the input for Coat decoderlayer
1339
+ hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
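+ # The embedding output is quantized once here; every decoder layer then receives and returns
+ # the (bf16 residual, fp8 activation, per-group scale) triplet.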
1340
+
1341
+ for decoder_layer in self.layers:
1342
+ if output_hidden_states:
1343
+ all_hidden_states += (hidden_states,)
1344
+
1345
+ if self.gradient_checkpointing and self.training:
1346
+ layer_outputs = self._gradient_checkpointing_func(
1347
+ decoder_layer.__call__,
1348
+ hidden_states,
1349
+ quant_hidden_states,
1350
+ scale_hidden_states,
1351
+ causal_mask,
1352
+ position_ids,
1353
+ past_key_values,
1354
+ output_attentions,
1355
+ use_cache,
1356
+ cache_position,
1357
+ position_embeddings,
1358
+ )
1359
+ else:
1360
+ layer_outputs = decoder_layer(
1361
+ hidden_states,
1362
+ quant_hidden_states,
1363
+ scale_hidden_states,
1364
+ attention_mask=causal_mask,
1365
+ position_ids=position_ids,
1366
+ past_key_value=past_key_values,
1367
+ output_attentions=output_attentions,
1368
+ use_cache=use_cache,
1369
+ cache_position=cache_position,
1370
+ position_embeddings=position_embeddings,
1371
+ )
1372
+
1373
+ hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0]
1374
+
1375
+ if use_cache:
1376
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1377
+
1378
+ if output_attentions:
1379
+ all_self_attns += (layer_outputs[1],)
1380
+
1381
+ # Summarize the output of the Decoder Layer
1382
+ hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states)
1383
+
1384
+ hidden_states = self.norm(hidden_states)
1385
+
1386
+ # add hidden states from the last decoder layer
1387
+ if output_hidden_states:
1388
+ all_hidden_states += (hidden_states,)
1389
+
1390
+ next_cache = next_decoder_cache if use_cache else None
1391
+ if return_legacy_cache:
1392
+ next_cache = next_cache.to_legacy_cache()
1393
+
1394
+ if not return_dict:
1395
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1396
+ return BaseModelOutputWithPast(
1397
+ last_hidden_state=hidden_states,
1398
+ past_key_values=next_cache,
1399
+ hidden_states=all_hidden_states,
1400
+ attentions=all_self_attns,
1401
+ )
1402
+
1403
+ _update_causal_mask = LlamaModel._update_causal_mask
1404
+
1405
+
1406
+ class CoatLlamaForCausalLM(CoatLlamaPreTrainedModel, GenerationMixin):
1407
+ _tied_weights_keys = ["lm_head.weight"]
1408
+
1409
+ def __init__(self, config):
1410
+ super().__init__(config)
1411
+ self.model = CoatLlamaModel(config)
1412
+ self.vocab_size = config.vocab_size
1413
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1414
+
1415
+ # Initialize weights and apply final processing
1416
+ self.post_init()
1417
+
1418
+ def get_input_embeddings(self):
1419
+ return self.model.embed_tokens
1420
+
1421
+ def set_input_embeddings(self, value):
1422
+ self.model.embed_tokens = value
1423
+
1424
+ def get_output_embeddings(self):
1425
+ return self.lm_head
1426
+
1427
+ def set_output_embeddings(self, new_embeddings):
1428
+ self.lm_head = new_embeddings
1429
+
1430
+ def set_decoder(self, decoder):
1431
+ self.model = decoder
1432
+
1433
+ def get_decoder(self):
1434
+ return self.model
1435
+
1436
+ forward = LlamaForCausalLM.forward
1437
+
1438
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
1439
+
1440
+
1441
+ # TODO
1442
+ # class LlamaForSequenceClassification(LlamaPreTrainedModel):
1443
+
1444
+ # class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1445
+
1446
+ # class LlamaForTokenClassification(LlamaPreTrainedModel):
1447
+
1448
+
1449
+ def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]):
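+ # Remap a stock Hugging Face Llama state dict onto the COAT module layout, e.g.
+ # "model.layers.0.self_attn.q_proj.weight" -> "model.layers.0.BeforeAttention.q_proj.weight".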
1450
+ compatible_state_dict = {}
1451
+
1452
+ for key, value in state_dict.items():
1453
+ if fnmatch(key, "*self_attn.q_proj*"):
1454
+ new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj")
1455
+ elif fnmatch(key, "*self_attn.k_proj*"):
1456
+ new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj")
1457
+ elif fnmatch(key, "*self_attn.v_proj*"):
1458
+ new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj")
1459
+ elif fnmatch(key, "*self_attn.o_proj*"):
1460
+ new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj")
1461
+
1462
+ elif fnmatch(key, "*mlp.gate_proj*"):
1463
+ new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj")
1464
+ elif fnmatch(key, "*mlp.up_proj*"):
1465
+ new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj")
1466
+ elif fnmatch(key, "*mlp.down_proj*"):
1467
+ new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj")
1468
+
1469
+ else:
1470
+ new_key = key
1471
+
1472
+ compatible_state_dict[new_key] = value
1473
+
1474
+ return compatible_state_dict
1475
+
1476
+
1477
+ AutoConfig.register("fp8_llama", CoatLlamaConfig)
1478
+ AutoModel.register(CoatLlamaConfig, CoatLlamaModel)
1479
+ AutoModelForCausalLM.register(CoatLlamaConfig, CoatLlamaForCausalLM)
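+ # Minimal usage sketch (assuming this module has been imported so the registrations above are active;
+ # "path/to/converted_fp8_llama" is a hypothetical checkpoint produced by coat_llama_convert_from_hf.py):
+ #   from transformers import AutoModelForCausalLM
+ #   model = AutoModelForCausalLM.from_pretrained("path/to/converted_fp8_llama")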
llava/model/coat/activation/models/coat_llama_convert_from_hf.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import argparse
8
+ import os
9
+ from dataclasses import asdict, dataclass, field
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import transformers
14
+ from coat.activation.models._fp8_quantization_config import QuantizationConfig
15
+ from coat.activation.models.coat_llama import CoatLlamaConfig, CoatLlamaForCausalLM, make_state_dict_compatible
16
+ from transformers import AutoConfig, AutoModelForCausalLM
17
+
18
+
19
+ @dataclass
20
+ class ConvertArguments:
21
+ model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"})
22
+ save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"})
23
+ cache_dir: str = field(default=None, metadata={"help": "Directory to cache the model"})
24
+
25
+
26
+ def download_and_convert_llama(convert_args: ConvertArguments, quantization_args: QuantizationConfig):
27
+ """
28
+ Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`,
29
+ and saves the converted model.
30
+
31
+ Args:
32
+ convert_args (ConvertArguments): model_name, save_path, and optional cache_dir used for the conversion.
33
+ quantization_args (QuantizationConfig): FP8 quantization settings that are stored in the converted
34
+ config as `coat_fp8_args`.
35
+
36
+ Returns:
37
+ None
38
+ """
39
+ model_name = convert_args.model_name
40
+ save_path = convert_args.save_path
41
+ cache_dir = convert_args.cache_dir
42
+
43
+ # Step 1: Download the original LLaMA model
44
+ model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
45
+
46
+ # Step 2: Initialize the model configuration for FP8 or other custom config
47
+ config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
48
+
49
+ # Step 3: Apply make_state_dict_compatible to convert weights
50
+ compatible_state_dict = make_state_dict_compatible(model.state_dict())
51
+
52
+ # Step 4: Create a new model instance with compatible configuration
53
+ fp8_config = CoatLlamaConfig(**config.to_dict())
54
+ fp8_config.coat_fp8_args = asdict(quantization_args)
55
+
56
+ converted_model = AutoModelForCausalLM.from_config(fp8_config)
57
+ converted_model.load_state_dict(compatible_state_dict)
58
+
59
+ # Step 5: Save the converted model and configuration using save_pretrained
60
+ os.makedirs(save_path, exist_ok=True)
61
+ converted_model.save_pretrained(save_path)
62
+ print(f"Converted model saved at {save_path}")
63
+
64
+
65
+ if __name__ == "__main__":
66
+ # Parse command-line arguments
67
+ parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8
68
+ convert_args, quantization_args = parser.parse_args_into_dataclasses()
69
+
70
+ # Call the function with parsed arguments
71
+ download_and_convert_llama(convert_args, quantization_args)
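+ # Example invocation (model name and output path are illustrative; HfArgumentParser also exposes the
+ # QuantizationConfig fields as command-line flags):
+ #   python coat_llama_convert_from_hf.py --model_name meta-llama/Llama-2-7b-hf --save_path ./fp8_llama_converted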
llava/model/coat/activation/models/coat_olmo.py ADDED
@@ -0,0 +1,1942 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ """
24
+ Adapted from
25
+ [MosaiclML](https://github.com/mosaicml/examples.git) and
26
+ [minGPT](https://github.com/karpathy/minGPT.git)
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import logging
32
+ import math
33
+ import sys
34
+ from abc import abstractmethod
35
+ from collections import defaultdict
36
+ from functools import partial
37
+ from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set, Tuple, cast
38
+
39
+ import torch
40
+ import torch.backends.cuda
41
+ import torch.nn as nn
42
+ import torch.nn.functional as F
43
+ from olmo.aliases import PathOrStr
44
+ from olmo.beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
45
+ from olmo.config import (
46
+ ActivationCheckpointingStrategy,
47
+ ActivationType,
48
+ BlockType,
49
+ CheckpointType,
50
+ FSDPWrapStrategy,
51
+ InitFnType,
52
+ LayerNormType,
53
+ ModelConfig,
54
+ QuantActivationConfig,
55
+ ShardedCheckpointerType,
56
+ TrainConfig,
57
+ )
58
+ from olmo.exceptions import OLMoConfigurationError
59
+ from olmo.initialization import init_normal
60
+ from olmo.model import (
61
+ Activation,
62
+ BufferCache,
63
+ Dropout,
64
+ LayerNorm,
65
+ LayerNormBase,
66
+ OLMo,
67
+ OLMoBlock,
68
+ OLMoBlockGroup,
69
+ OLMoGenerateOutput,
70
+ OLMoOutput,
71
+ RMSLayerNorm,
72
+ RotaryEmbedding,
73
+ _non_meta_init_device,
74
+ activation_checkpoint_function,
75
+ alibi_attention_bias,
76
+ causal_attention_bias,
77
+ get_causal_attention_bias,
78
+ should_checkpoint_block,
79
+ )
80
+ from olmo.torch_util import ensure_finite_, get_cumulative_document_lengths
81
+ from torch import einsum
82
+
83
+ from ..real_quantization import (
84
+ Coat_quantize_bgn,
85
+ Coat_quantize_end,
86
+ fp8_add_Ifp_Ifp_Ofp_Og16,
87
+ fp8_add_Ifp_Ifp_Ofp_Opt,
88
+ fp8_division,
89
+ fp8_division_transpose,
90
+ fp8_gelu_backward,
91
+ fp8_gelu_forward,
92
+ fp8_layernorm_noparam_backward,
93
+ fp8_layernorm_noparam_forward,
94
+ fp8_linear_backward,
95
+ fp8_linear_forward,
96
+ fp8_mul_backward,
97
+ fp8_mul_forward,
98
+ fp8_quantize,
99
+ fp8_quantize_pertensor,
100
+ fp8_quantize_pertensor_transpose,
101
+ fp8_rmsnorm_backward,
102
+ fp8_rmsnorm_forward,
103
+ fp8_silu_backward,
104
+ fp8_silu_forward,
105
+ fp8_transpose,
106
+ )
107
+ from ._fp8_weightcache import FP8CacheWeightModule
108
+ from ._fp8manager import FP8Manager
109
+
110
+ if sys.version_info.minor > 8:
111
+ from collections.abc import MutableMapping
112
+ elif sys.version_info.minor == 8:
113
+ from typing import MutableMapping
114
+ else:
115
+ raise SystemExit("This script supports Python 3.8 or higher")
116
+
117
+ __all__ = [
118
+ "LayerNormBase",
119
+ "LayerNorm",
120
+ "RMSLayerNorm",
121
+ "RotaryEmbedding",
122
+ "Activation",
123
+ "GELU",
124
+ "ReLU",
125
+ "SwiGLU",
126
+ "OLMoBlock",
127
+ "OLMoSequentialBlock",
128
+ "OLMo",
129
+ "OLMoOutput",
130
+ "OLMoGenerateOutput",
131
+ ]
132
+
133
+
134
+ log = logging.getLogger(__name__)
135
+
136
+
137
+ class CoatOLMoBeforeAttentionResidual(FP8CacheWeightModule):
138
+ """
139
+ This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers
140
+ """
141
+
142
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, fused_dims: tuple):
143
+ super().__init__(config, qargs, layer_id)
144
+
145
+ self.qargs = qargs
146
+ self.fwobits = {
147
+ "fabit": self.qargs.fabit,
148
+ "fwbit": self.qargs.fwbit,
149
+ "fobit": self.qargs.fobit,
150
+ "babit": self.qargs.babit,
151
+ "bwbit": self.qargs.bwbit,
152
+ "bobit": self.qargs.bobit,
153
+ }
154
+ self.ln_normalized_shape = config.d_model
155
+ self.att_proj = nn.Linear(config.d_model, sum(fused_dims), bias=config.include_bias, device=config.init_device)
156
+
157
+ self.attn_norm = LayerNorm.build(config)
158
+
159
+ def forward(self, re_x, x, s):
160
+ if self.training:
161
+ if self.qargs.weight_memory_efficient:
162
+ # Prepare
163
+ with torch.no_grad():
164
+ weight1_s = self.prepare_weight(self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch)
165
+ return _CoatOLMoBeforeAttentionResidual.apply(
166
+ re_x,
167
+ x,
168
+ s,
169
+ self.att_proj.weight,
170
+ None,
171
+ None,
172
+ weight1_s,
173
+ self.qargs.group_size,
174
+ self.fwobits,
175
+ self.layer_id,
176
+ self.config,
177
+ self.qargs,
178
+ )
179
+ else:
180
+ # Prepare
181
+ with torch.no_grad():
182
+ weight1, weight1_t, weight1_s = self.prepare_weight(
183
+ self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch
184
+ )
185
+ return _CoatOLMoBeforeAttentionResidual.apply(
186
+ re_x,
187
+ x,
188
+ s,
189
+ self.att_proj.weight,
190
+ weight1,
191
+ weight1_t,
192
+ weight1_s,
193
+ self.qargs.group_size,
194
+ self.fwobits,
195
+ self.layer_id,
196
+ self.config,
197
+ self.qargs,
198
+ )
199
+ else:
200
+ return re_x, self.att_proj(self.attn_norm(re_x))
201
+
202
+
203
+ class _CoatOLMoBeforeAttentionResidual(torch.autograd.Function):
204
+ @staticmethod
205
+ def forward(
206
+ ctx,
207
+ re_x,
208
+ in_x,
209
+ in_s,
210
+ weight1_origin,
211
+ weight1,
212
+ weight1_t,
213
+ weight1_s,
214
+ group_size,
215
+ fwobits,
216
+ layer_id,
217
+ config,
218
+ qargs,
219
+ eps=1e-5,
220
+ ):
221
+ # for autograd
222
+ if fwobits["fabit"] == "E4M3":
223
+ # in_x = in_x.to(torch.float8_e4m3fn)
224
+ in_x = in_x.view(torch.float8_e4m3fn)
225
+ else:
226
+ raise ValueError("fabit should be E4M3")
227
+
228
+ # LayerNorm
229
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward(
230
+ in_x, in_s, group_size, eps, transpose_output_2d=True
231
+ )
232
+
233
+ # Linear Layer QKV Projection
234
+ if qargs.weight_memory_efficient:
235
+ assert weight1 is None # memory efficient
236
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
237
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size)
238
+
239
+ # ==================== save for backward ====================
240
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
241
+ if qargs.weight_memory_efficient:
242
+ assert weight1_t is None
243
+ ctx.weight = weight1_origin, weight1_s
244
+ else:
245
+ ctx.weight = weight1_t, weight1_s
246
+ ctx.group_size = group_size
247
+ ctx.ln_utils = ln_utils
248
+ ctx.utils = fwobits, layer_id, config, qargs
249
+
250
+ return re_x, fc1_x
251
+
252
+ @staticmethod
253
+ def backward(ctx, fp_grad, flash_g):
254
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
255
+ weight1_t, weight1_s = ctx.weight
256
+ group_size = ctx.group_size
257
+ mean, rstd, num_warps = ctx.ln_utils
258
+ fwobits, layer_id, config, qargs = ctx.utils
259
+
260
+ # ==================== Begin backward ====================
261
+ # Quantize the gradient of the RoPE / FlashAttention output. grad_input and grad_weight require different data layouts.
262
+ flash_g, flash_gs, flash_g_t = fp8_quantize_pertensor_transpose(
263
+ flash_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
264
+ )
265
+
266
+ # Linear Layer QKV Projection
267
+ if qargs.weight_memory_efficient:
268
+ weight1_t, weight1_s = fp8_division_transpose(
269
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
270
+ )
271
+ fc1_g, att_proj_wg = fp8_linear_backward(
272
+ ln_x_t, ln_s, flash_g, flash_gs, flash_g_t, weight1_t, weight1_s, group_size
273
+ )
274
+
275
+ # LayerNorm
276
+ in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps)
277
+
278
+ # Add the gradients together and prepare the input of the next layer.
279
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
280
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
281
+ )
282
+
283
+ # For autograd: the forward dtype must match the dtype of the backward tensor. The view does not change the actual binary representation.
284
+ in_g = in_g.view(torch.float8_e4m3fn)
285
+
286
+ # Although the next operator is a linear layer in the MLPResidual module, we return in_sg_g16 so the output sizes match the forward inputs; otherwise it will not pass autograd's checks.
287
+ return re_g, in_g, in_sg_g16, att_proj_wg, None, None, None, None, None, None, None, None, None
288
+
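+ # A minimal, illustrative autograd sketch (`_SketchScale` is hypothetical and not used by COAT):
+ # backward must return exactly one gradient (or None) per forward input, which is why the
+ # backward above ends with a long tail of None values for its non-tensor arguments.
+ class _SketchScale(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x, scale):
+         ctx.scale = scale
+         return x * scale
+
+     @staticmethod
+     def backward(ctx, grad_out):
+         # One entry per forward argument; the non-tensor `scale` receives None.
+         return grad_out * ctx.scale, None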
289
+
290
+ class CoatOLMoAfterAttentionResidual(FP8CacheWeightModule):
291
+ """
292
+ This is a typical transformer post-attention module that contains (1) a residual connection and (2) one linear layer (the attention output projection).
293
+ """
294
+
295
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id):
296
+ super().__init__(config, qargs, layer_id)
297
+
298
+ self.qargs = qargs
299
+ self.fwobits = {
300
+ "fabit": self.qargs.fabit,
301
+ "fwbit": self.qargs.fwbit,
302
+ "fobit": self.qargs.fobit,
303
+ "babit": self.qargs.babit,
304
+ "bwbit": self.qargs.bwbit,
305
+ "bobit": self.qargs.bobit,
306
+ }
307
+ self.attn_out = nn.Linear(config.d_model, config.d_model, bias=config.include_bias, device=config.init_device)
308
+
309
+ def forward(self, re_x, in_x):
310
+ if self.training:
311
+ if self.qargs.weight_memory_efficient:
312
+ # prepare for the weight
313
+ with torch.no_grad():
314
+ weight2_s = self.prepare_weight(self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch)
315
+
316
+ return _CoatOLMoAfterAttentionResidual.apply(
317
+ re_x,
318
+ in_x,
319
+ self.attn_out.weight,
320
+ None,
321
+ None,
322
+ weight2_s,
323
+ self.qargs.group_size,
324
+ self.fwobits,
325
+ self.layer_id,
326
+ self.config,
327
+ self.qargs,
328
+ )
329
+ else:
330
+ # prepare for the weight
331
+ with torch.no_grad():
332
+ weight2, weight2_t, weight2_s = self.prepare_weight(
333
+ self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch
334
+ )
335
+
336
+ return _CoatOLMoAfterAttentionResidual.apply(
337
+ re_x,
338
+ in_x,
339
+ self.attn_out.weight,
340
+ weight2,
341
+ weight2_t,
342
+ weight2_s,
343
+ self.qargs.group_size,
344
+ self.fwobits,
345
+ self.layer_id,
346
+ self.config,
347
+ self.qargs,
348
+ )
349
+ else:
350
+ return re_x + self.attn_out(in_x), None, None
351
+
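+ # A minimal, self-contained sketch of the per-group scaling idea behind `group_size`
+ # (illustrative only; the real work is done by the fused Triton kernels, and
+ # `_sketch_per_group_fp8_scale` is a hypothetical helper): each contiguous group shares one
+ # scale equal to its absmax divided by the FP8 E4M3 maximum representable value (448).
+ def _sketch_per_group_fp8_scale():
+     import torch
+
+     group_size, fp8_max = 16, 448.0
+     x = torch.randn(4, 64)
+     groups = x.view(4, -1, group_size)
+     scale = (groups.abs().amax(dim=-1, keepdim=True) / fp8_max).clamp(min=1e-12)
+     q = (groups / scale).clamp(-fp8_max, fp8_max)  # would be cast to torch.float8_e4m3fn on GPU
+     x_hat = (q * scale).view_as(x)                 # dequantized approximation of x
+     return (x - x_hat).abs().max()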
352
+
353
+ class _CoatOLMoAfterAttentionResidual(torch.autograd.Function):
354
+ @staticmethod
355
+ def forward(
356
+ ctx, re_x, flash_x, weight2_origin, weight2, weight2_t, weight2_s, group_size, fwobits, layer_id, config, qargs
357
+ ):
358
+ # Quantize the FlashAttention Output
359
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
360
+ flash_x, group_size, fwobits["fabit"]
361
+ ) # Modified to make it memory efficient
362
+
363
+ # Attention Projection Linear Layer
364
+ if qargs.weight_memory_efficient:
365
+ assert weight2 is None # memory efficient
366
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
367
+ fc2_x = fp8_linear_forward(flash_qx, flash_s, weight2, weight2_s, False, group_size)
368
+
369
371
+ # Add the activations together
372
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, flash_qx.dtype, group_size)
373
+
374
+ # ==================== save for backward ====================
375
+ ctx.save_for_backward(flash_x, flash_s)
376
+ if qargs.weight_memory_efficient:
377
+ assert weight2_t is None
378
+ ctx.weight = weight2_origin, weight2_s
379
+ else:
380
+ ctx.weight = weight2_t, weight2_s
381
+ ctx.group_size = group_size
382
+ ctx.fwobits = fwobits
383
+ ctx.utils = fwobits, layer_id, config, qargs
384
+
385
+ # For autograd
386
+ out_x = out_x.view(torch.float8_e4m3fn)
387
+
388
+ return fp_x, out_x, out_s
389
+
390
+ @staticmethod
391
+ def backward(ctx, fp_grad, out_g, out_gs):
392
+ flash_x, flash_s = ctx.saved_tensors
393
+ weight2_t, weight2_s = ctx.weight
394
+ group_size = ctx.group_size
395
+ fwobits = ctx.fwobits
396
+ fwobits, layer_id, config, qargs = ctx.utils
397
+
398
+ # for autograd
399
+ if fwobits["babit"] == "E5M2":
400
+ # out_g = out_g.to(torch.float8_e5m2)
401
+ out_g = out_g.view(torch.float8_e5m2)
402
+ else:
403
+ raise ValueError("babit should be E5M2")
404
+ out_gs_max = out_gs.max()
405
+
406
+ # ==================== Begin backward ====================
407
+ # Output Projection
408
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
409
+
410
+ # We do not save an extra transposed copy of flash_x; it is recomputed here to reduce memory usage
411
+ flash_x_t, flash_s = fp8_division_transpose(
412
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
413
+ )
414
+
415
+ if qargs.weight_memory_efficient:
416
+ weight2_t, weight2_s = fp8_division_transpose(
417
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
418
+ )
419
+ fc2_g, attn_out_wg = fp8_linear_backward(
420
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size
421
+ )
422
+
423
+ return fp_grad, fc2_g, attn_out_wg, None, None, None, None, None, None, None, None
424
+
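+ # A minimal sketch (hypothetical shapes; not the fused FP8 kernels) of why the backward
+ # passes above keep transposed copies (`*_t`): grad_input and grad_weight consume the same
+ # tensors in different layouts.
+ def _sketch_linear_backward_layouts():
+     import torch
+
+     x = torch.randn(32, 64)    # activations, (tokens, in_features)
+     w = torch.randn(128, 64)   # weight, (out_features, in_features)
+     g = torch.randn(32, 128)   # upstream gradient, (tokens, out_features)
+     grad_input = g @ w         # (tokens, in_features)
+     grad_weight = g.t() @ x    # (out_features, in_features)
+     return grad_input.shape, grad_weight.shape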
425
+
426
+ class CoatOLMoMLPResidual(FP8CacheWeightModule):
427
+ """
428
+ This is a typical transformer MLP module that contains (1) a residual connection, (2) LayerNorm / RMSNorm, (3) two or three linear layers,
429
+ and (4) a GELU / SiLU activation.
430
+ """
431
+
432
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, hidden_size: int):
433
+ super().__init__(config, qargs, layer_id)
434
+
435
+ self.qargs = qargs
436
+ self.fwobits = {
437
+ "fabit": self.qargs.fabit,
438
+ "fwbit": self.qargs.fwbit,
439
+ "fobit": self.qargs.fobit,
440
+ "babit": self.qargs.babit,
441
+ "bwbit": self.qargs.bwbit,
442
+ "bobit": self.qargs.bobit,
443
+ }
444
+ self.ln_normalized_shape = config.d_model
445
+ self.act_output_multiplier = 0.5 if config.activation_type == ActivationType.swiglu else 1
446
+ self.ff_proj = nn.Linear(config.d_model, hidden_size, bias=config.include_bias, device=config.init_device)
447
+ self.ff_out = nn.Linear(
448
+ int(self.act_output_multiplier * hidden_size),
449
+ config.d_model,
450
+ bias=config.include_bias,
451
+ device=config.init_device,
452
+ )
453
+ self.training = True
454
+
455
+ # The modules below are only used when training is False (the unquantized eval path).
456
+ self.ff_norm = LayerNorm.build(config)
457
+ self.act = Activation.build(config)
458
+ assert (self.act.output_multiplier * hidden_size) % 1 == 0
459
+
460
+ def forward(self, re_x, x, s):
461
+ if self.training:
462
+ if self.qargs.weight_memory_efficient: # prepare for the weight
463
+ with torch.no_grad():
464
+ weight1_s = self.prepare_weight(self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch)
465
+ weight2_s = self.prepare_weight(self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch)
466
+
467
+ return _CoatOLMoMLPResidual.apply(
468
+ re_x,
469
+ x,
470
+ s,
471
+ self.ff_proj.weight,
472
+ None,
473
+ None,
474
+ weight1_s,
475
+ self.ff_out.weight,
476
+ None,
477
+ None,
478
+ weight2_s,
479
+ self.qargs.group_size,
480
+ self.fwobits,
481
+ self.layer_id,
482
+ self.config,
483
+ self.qargs,
484
+ )
485
+ else:
486
+ # prepare for the weight
487
+ with torch.no_grad():
488
+ weight1, weight1_t, weight1_s = self.prepare_weight(
489
+ self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch
490
+ )
491
+ weight2, weight2_t, weight2_s = self.prepare_weight(
492
+ self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch
493
+ )
494
+
495
+ return _CoatOLMoMLPResidual.apply(
496
+ re_x,
497
+ x,
498
+ s,
499
+ self.ff_proj.weight,
500
+ weight1,
501
+ weight1_t,
502
+ weight1_s,
503
+ self.ff_out.weight,
504
+ weight2,
505
+ weight2_t,
506
+ weight2_s,
507
+ self.qargs.group_size,
508
+ self.fwobits,
509
+ self.layer_id,
510
+ self.config,
511
+ self.qargs,
512
+ )
513
+ else:
514
+ og_x = re_x
515
+ re_x = self.ff_norm(re_x)
516
+ re_x = self.ff_proj(re_x)
517
+ re_x = self.act(re_x)
518
+ re_x = self.ff_out(re_x)
519
+ re_x = og_x + re_x
520
+ return re_x, None, None
521
+
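+ # A minimal sketch (hypothetical sizes; `_sketch_swiglu_chunk` is illustrative only) of the
+ # fused up/gate projection handled above: ff_proj emits twice the MLP width, which is
+ # chunked into (up, gate) and combined as silu(gate) * up, so ff_out only sees half of
+ # ff_proj's output width (hence act_output_multiplier = 0.5 for SwiGLU).
+ def _sketch_swiglu_chunk():
+     import torch
+     import torch.nn.functional as F
+
+     fc1 = torch.randn(2, 16, 256)     # fused up+gate activation
+     up, gate = fc1.chunk(2, dim=-1)   # order must match the training kernels above
+     mlp_in = F.silu(gate) * up        # shape (2, 16, 128)
+     return mlp_in.shape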
522
+
523
+ class _CoatOLMoMLPResidual(torch.autograd.Function):
524
+ @staticmethod
525
+ def forward(
526
+ ctx,
527
+ re_x,
528
+ in_x,
529
+ in_s,
530
+ weight1_origin,
531
+ weight1,
532
+ weight1_t,
533
+ weight1_s,
534
+ weight2_origin,
535
+ weight2,
536
+ weight2_t,
537
+ weight2_s,
538
+ group_size,
539
+ fwobits,
540
+ layer_id,
541
+ config,
542
+ qargs,
543
+ eps=1e-5,
544
+ ):
545
+ # For autograd
546
+ if fwobits["fabit"] == "E4M3":
547
+ # in_x = in_x.to(torch.float8_e4m3fn)
548
+ in_x = in_x.view(torch.float8_e4m3fn)
549
+ else:
550
+ raise ValueError("fabit should be E4M3")
551
+
552
+ # LayerNorm
553
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward(
554
+ in_x, in_s, group_size, eps, transpose_output_2d=True
555
+ )
556
+
557
+ # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer.
558
+ if qargs.weight_memory_efficient:
559
+ assert weight1 is None # memory efficient
560
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
561
+ fc1_x, fc1_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size)
562
+
563
+ # NOTE: Be careful of the order: the first chunk is the up projection, the second is the gate.
564
+ up_x, gate_x = fc1_x.chunk(2, dim=-1)
565
+ up_s, gate_s = fc1_s.chunk(2, dim=-1)
566
+
567
+ # silu Activation
568
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
569
+
570
+ # Element-wise Multiplication
571
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
572
+
573
+ # Output Projection
574
+ if weight2 is None: # memory efficient
575
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
576
+ fc2_x = fp8_linear_forward(mul_x, mul_s, weight2, weight2_s, False, group_size)
577
+
578
+ # Add the activations together
579
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, mul_x.dtype, group_size)
580
+
581
+ # ==================== save for backward ====================
582
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
583
+
584
585
+ if (
586
+ qargs.weight_memory_efficient
587
+ ): # weight1_origin / weight2_origin will not be saved twice, so this is more memory efficient.
588
+ assert weight1_t is None
589
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s)
590
+ else: # weight1_t / weight2_t differ from the original weights, so saving them consumes an additional memory footprint.
591
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s)
592
+
593
+ ctx.group_size = group_size
594
+ ctx.ln_utils = ln_utils
595
+ ctx.utils = fwobits, layer_id, config, qargs
596
+
597
+ out_x = out_x.view(torch.float8_e4m3fn)
598
+
599
+ return fp_x, out_x, out_s
600
+
601
+ @staticmethod
602
+ def backward(ctx, fp_grad, out_g, out_gs):
603
+ fwobits, layer_id, config, qargs = ctx.utils
604
+
605
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
606
+
607
+ (weight1_t, weight1_s, weight2_t, weight2_s) = ctx.weight
608
+ group_size = ctx.group_size
609
+ mean, rstd, num_warps = ctx.ln_utils
610
+ fwobits, layer_id, config, qargs = ctx.utils
611
+
612
+ # For autograd
613
+ if fwobits["babit"] == "E5M2":
614
+ # out_g = out_g.to(torch.float8_e5m2)
615
+ out_g = out_g.view(torch.float8_e5m2)
616
+ else:
617
+ raise ValueError("babit should be E5M2")
618
+ out_gs_max = out_gs.max()
619
+
620
+ # ==================== Begin backward ====================
621
+ # Output Projection
622
623
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
624
+
625
+ if qargs.weight_memory_efficient:
626
+ weight2_t, weight2_s = fp8_division_transpose(
627
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
628
+ )
629
+ fc2_g, weight2_grad = fp8_linear_backward(
630
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size
631
+ )
632
+
633
+ # [MEM TEST]
634
+ del out_g, out_g_t, weight2_t
635
+
636
+ # Element-wise Multiplication, 1 means gate, 2 means up
637
+ mul_g1, (mul_g2, mul_gs2) = fp8_mul_backward(silu_x, silu_s, up_x, up_s, fc2_g, group_size, fwobits["babit"])
638
+
639
+ # Silu activation
640
+ silu_g, silu_gs = fp8_silu_backward(gate_x, gate_s, mul_g1, group_size, fwobits["babit"])
641
+
642
+ # Prepare the input of the linear layer. NOTE: Be careful of the order; it must mirror the forward chunk (up gradient first, then gate gradient).
643
+ gateup_g = torch.cat([mul_g2, silu_g], dim=-1)
644
+ gateup_gs = torch.cat([mul_gs2, silu_gs])
645
+ gateup_gs = torch.max(gateup_gs)
646
+
647
+ gateup_g, gateup_gs, gateup_g_t = fp8_division_transpose(
648
+ gateup_g, group_size, fwobits["babit"], gateup_gs, stochastic=False
649
+ )
650
+
651
+ # Linear Layer of Up and Gate Projection
652
+ if qargs.weight_memory_efficient:
653
+ weight1_t, weight1_s = fp8_division_transpose(
654
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
655
+ )
656
+ fc1_g, weight1_grad = fp8_linear_backward(
657
+ ln_x_t, ln_s, gateup_g, gateup_gs, gateup_g_t, weight1_t, weight1_s, group_size
658
+ )
659
+
660
+ # LayerNorm
661
+ in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps)
662
+
663
+ # Add the gradients together
664
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
665
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
666
+ )
667
+
668
+ in_g = in_g.view(torch.float8_e4m3fn)
669
+
670
+ return (
671
+ re_g,
672
+ in_g,
673
+ in_sg_g16,
674
+ weight1_grad,
675
+ None,
676
+ None,
677
+ None,
678
+ weight2_grad,
679
+ None,
680
+ None,
681
+ None,
682
+ None,
683
+ None,
684
+ None,
685
+ None,
686
+ None,
687
+ None,
688
+ )
689
+
690
+
691
+ class CoatOLMoBlock(nn.Module):
692
+ """
693
+ A base class for transformer block implementations.
694
+ """
695
+
696
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
697
+ super().__init__()
698
+ self.layer_id = layer_id
699
+ self.config = config
700
+ self.qargs = qargs
701
+ self.hidden_size = (
702
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
703
+ )
704
+ self.__cache = cache
705
+ assert config.d_model % config.n_heads == 0
706
+
707
+ self._activation_checkpoint_fn: Callable | None = None
708
+
709
+ # Dropout.
710
+ self.dropout = Dropout(config.residual_dropout)
711
+
712
+ # Layer norms.
713
+ self.k_norm: LayerNormBase | None = None
714
+ self.q_norm: LayerNormBase | None = None
715
+ if config.attention_layer_norm:
716
+ assert config.effective_n_kv_heads is not None
717
+ self.k_norm = LayerNormBase.build(
718
+ config,
719
+ size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
720
+ elementwise_affine=config.attention_layer_norm_with_affine,
721
+ )
722
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
723
+
724
+ # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
725
+ if config.clip_qkv is not None:
726
+ assert config.clip_qkv > 0
727
+
728
+ # Activation function.
729
+ self.act = Activation.build(config)
730
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
731
+
732
+ if not self.qargs.use_quantize_model:
733
+ # Attention output projection.
734
+ self.attn_out = nn.Linear(
735
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
736
+ )
737
+
738
+ # Feed-forward output projection.
739
+ self.ff_out = nn.Linear(
740
+ int(self.act.output_multiplier * self.hidden_size),
741
+ config.d_model,
742
+ bias=config.include_bias,
743
+ device=config.init_device,
744
+ )
745
+ self.ff_out._is_residual = True # type: ignore
746
+
747
+ # Rotary embeddings.
748
+ if self.config.rope:
749
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
750
+
751
+ self.flash_attn_func = None
752
+ self.flash_attn_varlen_func = None
753
+ if config.flash_attention:
754
+ try:
755
+ from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
756
+
757
+ self.flash_attn_func = flash_attn_func
758
+ self.flash_attn_varlen_func = flash_attn_varlen_func
759
+ except ModuleNotFoundError:
760
+ pass
761
+
762
+ def reset_parameters(self):
763
+ if self.k_norm is not None:
764
+ self.k_norm.reset_parameters()
765
+ if self.q_norm is not None:
766
+ self.q_norm.reset_parameters()
767
+
768
+ if not self.qargs.use_quantize_model:
769
+ if self.config.init_fn == InitFnType.normal:
770
+ attn_out_std = ff_out_std = self.config.init_std
771
+ cutoff_factor = self.config.init_cutoff_factor
772
+
773
+ elif self.config.init_fn == InitFnType.mitchell:
774
+ attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
775
+ ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1)))
776
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
777
+
778
+ elif self.config.init_fn == InitFnType.full_megatron:
779
+ attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
780
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
781
+
782
+ else:
783
+ raise NotImplementedError(self.config.init_fn)
784
+
785
+ init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
786
+ init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
787
+
788
+ def set_activation_checkpointing(
789
+ self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None
790
+ ):
791
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
792
+ self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
793
+ else:
794
+ self._activation_checkpoint_fn = None
795
+
796
+ @classmethod
797
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
798
+ target_dtype = input_dtype
799
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
800
+ # `is_autocast_cpu_enabled()` for CPU autocast.
801
+ # See https://github.com/pytorch/pytorch/issues/110966.
802
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
803
+ target_dtype = torch.get_autocast_gpu_dtype()
804
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
805
+ target_dtype = torch.get_autocast_cpu_dtype()
806
+ if bias.dtype != target_dtype:
807
+ bias = bias.to(target_dtype)
808
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
809
+ return bias
810
+
811
+ def _scaled_dot_product_attention(
812
+ self,
813
+ q: torch.Tensor,
814
+ k: torch.Tensor,
815
+ v: torch.Tensor,
816
+ attn_mask: torch.Tensor | None = None,
817
+ dropout_p: float = 0.0,
818
+ is_causal: bool = False,
819
+ max_doc_len: int | None = None,
820
+ cu_doc_lens: torch.Tensor | None = None,
821
+ ) -> torch.Tensor:
822
+ """
823
+ Computes scaled dot product attention on query, key and value tensors, using an optional
824
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
825
+ """
826
+ if max_doc_len is not None and cu_doc_lens is not None:
827
+ assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking"
828
+ assert attn_mask is None, "attn-mask is currently not supported with document masking"
829
+ B, T, D = q.size(0), q.size(2), q.size(3)
830
+ r = self.flash_attn_varlen_func(
831
+ q.transpose(1, 2).view(B * T, -1, D),
832
+ k.transpose(1, 2).view(B * T, -1, D),
833
+ v.transpose(1, 2).view(B * T, -1, D),
834
+ cu_doc_lens,
835
+ cu_doc_lens,
836
+ max_doc_len,
837
+ max_doc_len,
838
+ dropout_p=dropout_p,
839
+ causal=is_causal,
840
+ )
841
+ return r.view(B, T, -1, D).transpose(1, 2)
842
+ elif self.flash_attn_func is not None and attn_mask is None:
843
+ r = self.flash_attn_func(
844
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
845
+ )
846
+ return r.transpose(1, 2)
847
+ else:
848
+ # torch's sdpa doesn't support GQA, so we're doing this
849
+ assert k.size(1) == v.size(1)
850
+ num_kv_heads = k.size(1)
851
+ num_q_heads = q.size(1)
852
+ if num_q_heads != num_kv_heads:
853
+ assert num_q_heads % num_kv_heads == 0
854
+ k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
855
+ v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
856
+
857
+ return F.scaled_dot_product_attention(
858
+ q,
859
+ k,
860
+ v,
861
+ attn_mask=attn_mask,
862
+ dropout_p=dropout_p,
863
+ is_causal=is_causal,
864
+ )
865
+
866
+ def attention(
867
+ self,
868
+ q: torch.Tensor,
869
+ k: torch.Tensor,
870
+ v: torch.Tensor,
871
+ attention_bias: torch.Tensor | None = None,
872
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
873
+ use_cache: bool = False,
874
+ max_doc_len: int | None = None,
875
+ cu_doc_lens: torch.Tensor | None = None,
876
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
877
+ B, T, C = q.size() # batch size, sequence length, d_model
878
+ dtype = k.dtype
879
+
880
+ # Optionally apply layer norm to keys and queries.
881
+ if self.q_norm is not None and self.k_norm is not None:
882
+ q = self.q_norm(q).to(dtype=dtype)
883
+ k = self.k_norm(k).to(dtype=dtype)
884
+
885
+ # Move head forward to be next to the batch dim.
886
+ # shape: (B, nh, T, hs)
887
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
888
+ # shape: (B, n_kv_h, T, hs)
889
+ k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
890
+ # shape: (B, n_kv_h, T, hs)
891
+ v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
892
+
893
+ if layer_past is not None:
894
+ past_key, past_value = layer_past
895
+ k = torch.cat((past_key, k), dim=-2)
896
+ v = torch.cat((past_value, v), dim=-2)
897
+
898
+ present = (k, v) if use_cache else None
899
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
900
+
901
+ if self.config.rope:
902
+ # Apply rotary embeddings.
903
+ q, k = self.rotary_emb(q, k)
904
+
905
+ if attention_bias is not None:
906
+ # Resize and cast attention bias.
907
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
908
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
909
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
910
+ # cause the SDP attn function to produce NaNs.
911
+ attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype)
912
+
913
+ # Get the attention scores.
914
+ # shape: (B, nh, T, hs)
915
+ att = self._scaled_dot_product_attention(
916
+ q,
917
+ k,
918
+ v,
919
+ attn_mask=attention_bias,
920
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
921
+ is_causal=attention_bias is None,
922
+ max_doc_len=max_doc_len,
923
+ cu_doc_lens=cu_doc_lens,
924
+ )
925
+
926
+ # Re-assemble all head outputs side-by-side.
927
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
928
+
929
+ # Apply output projection. NOTE: We move the attn output outside of this attention function
930
+ return att, present
931
+
932
+ @abstractmethod
933
+ def forward(
934
+ self,
935
+ x: torch.Tensor,
936
+ attention_bias: torch.FloatTensor | None = None,
937
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
938
+ use_cache: bool = False,
939
+ max_doc_len: int | None = None,
940
+ cu_doc_lens: torch.Tensor | None = None,
941
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
942
+ raise NotImplementedError
943
+
944
+ @classmethod
945
+ def build(cls, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache) -> OLMoBlock:
946
+ if config.block_type == BlockType.sequential:
947
+ return CoatOLMoSequentialBlock(layer_id, config, qargs, cache)
948
+ elif config.block_type == BlockType.llama:
949
+ return CoatOLMoLlamaBlock(layer_id, config, qargs, cache)
950
+ else:
951
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
952
+
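+ # A minimal sketch (hypothetical shapes) of the GQA fallback in
+ # CoatOLMoBlock._scaled_dot_product_attention: k / v are expanded with repeat_interleave so
+ # that every query head has a matching key/value head before calling torch's SDPA.
+ def _sketch_gqa_repeat_kv():
+     import torch
+     import torch.nn.functional as F
+
+     B, T, head_dim, n_q_heads, n_kv_heads = 2, 8, 16, 8, 2
+     q = torch.randn(B, n_q_heads, T, head_dim)
+     k = torch.randn(B, n_kv_heads, T, head_dim)
+     v = torch.randn(B, n_kv_heads, T, head_dim)
+     k = k.repeat_interleave(n_q_heads // n_kv_heads, dim=1)
+     v = v.repeat_interleave(n_q_heads // n_kv_heads, dim=1)
+     return F.scaled_dot_product_attention(q, k, v, is_causal=True).shape  # (B, n_q_heads, T, head_dim)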
953
+
954
+ class CoatOLMoSequentialBlock(CoatOLMoBlock):
955
+ """
956
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
957
+ (plus another skip connection). To compute it as ``LN(MLP(x + LN(Attention(x))))``,
958
+ use the flag `norm_after`.
959
+ """
960
+
961
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
962
+ super().__init__(layer_id, config, qargs, cache)
963
+ # Attention input projection. Projects x -> (q, k, v)
964
+
965
+ assert not self.config.norm_after, "COAT currently does not support PostNorm"
966
+
967
+ head_dim = config.d_model // config.n_heads
968
+ self.fused_dims = (
969
+ config.d_model,
970
+ config.effective_n_kv_heads * head_dim,
971
+ config.effective_n_kv_heads * head_dim,
972
+ )
973
+
974
+ if self.qargs.use_quantize_model:
975
+ self.BeforeAttention = CoatOLMoBeforeAttentionResidual(config, qargs, self.layer_id, self.fused_dims)
976
+ self.AfterAttention = CoatOLMoAfterAttentionResidual(config, qargs, self.layer_id)
977
+ self.MLPResidual = CoatOLMoMLPResidual(config, qargs, self.layer_id, self.hidden_size)
978
+ else:
979
+ self.att_proj = nn.Linear(
980
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
981
+ )
982
+ # Feed-forward input projection.
983
+ self.ff_proj = nn.Linear(
984
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
985
+ )
986
+
987
+ # Layer norms.
988
+ self.attn_norm = LayerNorm.build(config, size=config.d_model)
989
+ self.ff_norm = LayerNorm.build(config, size=config.d_model)
990
+
991
+ def reset_parameters(self):
992
+ super().reset_parameters()
993
+ self.attn_norm.reset_parameters()
994
+ self.ff_norm.reset_parameters()
995
+ # NOTE: the standard deviation for these weights does not depend on the layer.
996
+
997
+ if self.qargs.use_quantize_model: # The initialization appears here, not in CoatOLMoBlock's reset_parameters
998
+ if self.config.init_fn == InitFnType.normal:
999
+ attn_out_std = ff_out_std = self.config.init_std
1000
+ cutoff_factor = self.config.init_cutoff_factor
1001
+
1002
+ elif self.config.init_fn == InitFnType.mitchell:
1003
+ attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
1004
+ ff_out_std = 1 / (math.sqrt(2 * self.MLPResidual.ff_out.in_features * (self.layer_id + 1)))
1005
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
1006
+
1007
+ elif self.config.init_fn == InitFnType.full_megatron:
1008
+ attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
1009
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
1010
+
1011
+ else:
1012
+ raise NotImplementedError(self.config.init_fn)
1013
+
1014
+ init_normal(self.AfterAttention.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
1015
+ init_normal(self.MLPResidual.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
1016
+
1017
+ if self.config.init_fn == InitFnType.normal:
1018
+ std = self.config.init_std
1019
+ cutoff_factor = self.config.init_cutoff_factor
1020
+ elif self.config.init_fn == InitFnType.mitchell:
1021
+ std = 1 / math.sqrt(self.config.d_model)
1022
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
1023
+ elif self.config.init_fn == InitFnType.full_megatron:
1024
+ std = self.config.init_std
1025
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
1026
+ else:
1027
+ raise NotImplementedError(self.config.init_fn)
1028
+
1029
+ if not self.qargs.use_quantize_model:
1030
+ init_normal(self.att_proj, std, cutoff_factor)
1031
+ init_normal(self.ff_proj, std, cutoff_factor)
1032
+ else:
1033
+ init_normal(self.BeforeAttention.att_proj, std, cutoff_factor)
1034
+ init_normal(self.MLPResidual.ff_proj, std, cutoff_factor)
1035
+
1036
+ def forward(
1037
+ self,
1038
+ x: torch.Tensor,
1039
+ qx: torch.Tensor,
1040
+ sx: torch.Tensor,
1041
+ attention_bias: torch.Tensor | None = None,
1042
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
1043
+ use_cache: bool = False,
1044
+ max_doc_len: int | None = None,
1045
+ cu_doc_lens: torch.Tensor | None = None,
1046
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
1047
+ # Get query, key, value projections.
1048
+ # shape:
1049
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
1050
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
1051
+ # k, v: (batch_size, seq_len, d_model // n_heads)
1052
+ # - for group query attn q: (batch_size, seq_len, d_model)
1053
+ # k, v: (batch_size, seq_len, d_model // n_kv_heads)
1054
+
1055
1057
+
1058
+ if self.qargs.use_quantize_model:
1059
+ # if False:
1060
+ x, qkv = self.BeforeAttention(x, qx, sx)
1061
+ else:
1062
+ # apply norm before
1063
+ h = self.attn_norm(x)
1064
+
1065
+ qkv = self.att_proj(h)
1066
+
1067
+ if self.config.clip_qkv is not None:
1068
+ qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
1069
+
1070
+ q, k, v = qkv.split(self.fused_dims, dim=-1)
1071
+
1072
+ # Get attention scores.
1073
+ att, cache = self.attention(
1074
+ q,
1075
+ k,
1076
+ v,
1077
+ attention_bias,
1078
+ layer_past=layer_past,
1079
+ use_cache=use_cache,
1080
+ max_doc_len=max_doc_len,
1081
+ cu_doc_lens=cu_doc_lens,
1082
+ )
1083
+
1084
1086
+ if self.qargs.use_quantize_model:
1087
+ # if False:
1088
+ x, qx, sx = self.AfterAttention(x, att)
1089
+ else:
1090
+ att = self.attn_out(att)
1091
+
1092
+ # Add attention scores.
1093
+ # shape: (B, T, C)
1094
+ x = x + self.dropout(att)
1095
+
1096
+ if self.qargs.use_quantize_model:
1097
+ # if False:
1098
+ x, qx, sx = self.MLPResidual(x, qx, sx)
1099
+ else:
1100
+ # Add feed-forward projection.
1101
+ # shape: (batch_size, seq_len, d_model)
1102
+ og_x = x
1103
+
1104
+ x = self.ff_norm(x)
1105
+
1106
+ x = self.ff_proj(x)
1107
+
1108
+ if self._activation_checkpoint_fn is not None:
1109
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
1110
+ else:
1111
+ x = self.act(x)
1112
+ x = self.ff_out(x)
1113
+
1114
+ x = self.dropout(x)
1115
+ x = og_x + x
1116
+
1117
1119
+
1120
+ return x, qx, sx, cache
1121
+
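+ # A minimal sketch (with hypothetical stand-in callables) of the pre-norm residual pattern
+ # this block implements; the quantized path above fuses these steps into the FP8 modules.
+ def _sketch_prenorm_block(x, attn, mlp, attn_norm, ff_norm):
+     x = x + attn(attn_norm(x))   # attention sub-layer with skip connection
+     x = x + mlp(ff_norm(x))      # MLP sub-layer with skip connection
+     return x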
1122
+
1123
+ class CoatOLMoLlamaBlock(OLMoBlock):
1124
+ """
1125
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
1126
+ (plus another skip connection). This block is similar to `OLMoSequentialBlock`
1127
+ but some operations have slightly different implementations to imitate the
1128
+ behavior of Llama.
1129
+ """
1130
+
1131
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
1132
+ super().__init__(layer_id, config, qargs, cache)
1133
+ # Layer norms.
1134
+ self.attn_norm = LayerNorm.build(config)
1135
+ self.ff_norm = LayerNorm.build(config)
1136
+ self.__cache = cache
1137
+
1138
+ # Attention input projection. Projects x -> (q, k, v)
1139
+ if config.multi_query_attention:
1140
+ q_proj_out_dim = config.d_model
1141
+ k_proj_out_dim = config.d_model // config.n_heads
1142
+ v_proj_out_dim = config.d_model // config.n_heads
1143
+ else:
1144
+ q_proj_out_dim = config.d_model
1145
+ k_proj_out_dim = config.d_model
1146
+ v_proj_out_dim = config.d_model
1147
+ self.q_proj = nn.Linear(config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device)
1148
+ self.k_proj = nn.Linear(config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device)
1149
+ self.v_proj = nn.Linear(config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device)
1150
+
1151
+ # Feed-forward input projection.
1152
+ self.ff_proj = nn.Linear(config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device)
1153
+
1154
+ def reset_parameters(self):
1155
+ super().reset_parameters()
1156
+ self.attn_norm.reset_parameters()
1157
+ self.ff_norm.reset_parameters()
1158
+ # NOTE: the standard deviation for these weights does not depend on the layer.
1159
+
1160
+ if self.config.init_fn == InitFnType.normal:
1161
+ std = self.config.init_std
1162
+ cutoff_factor = self.config.init_cutoff_factor
1163
+ elif self.config.init_fn == InitFnType.mitchell:
1164
+ std = 1 / math.sqrt(self.config.d_model)
1165
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
1166
+ elif self.config.init_fn == InitFnType.full_megatron:
1167
+ std = self.config.init_std
1168
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
1169
+ else:
1170
+ raise NotImplementedError(self.config.init_fn)
1171
+
1172
+ init_normal(self.q_proj, std, cutoff_factor)
1173
+ init_normal(self.k_proj, std, cutoff_factor)
1174
+ init_normal(self.v_proj, std, cutoff_factor)
1175
+ init_normal(self.ff_proj, std, cutoff_factor)
1176
+
1177
+ def _scaled_dot_product_attention(
1178
+ self,
1179
+ q: torch.Tensor,
1180
+ k: torch.Tensor,
1181
+ v: torch.Tensor,
1182
+ attn_mask: torch.Tensor | None = None,
1183
+ dropout_p: float = 0.0,
1184
+ is_causal: bool = False,
1185
+ max_doc_len: int | None = None,
1186
+ cu_doc_lens: torch.Tensor | None = None,
1187
+ ) -> torch.Tensor:
1188
+ if max_doc_len is not None or cu_doc_lens is not None:
1189
+ raise NotImplementedError(f"attention document masking is not implemented for {self.__class__.__name__}")
1190
+
1191
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
1192
+
1193
+ if is_causal:
1194
+ assert attn_mask is None
1195
+
1196
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
1197
+ attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
1198
+ elif attn_mask is not None:
1199
+ attn_bias = attn_mask.to(q.dtype)
1200
+ else:
1201
+ attn_bias = torch.zeros_like(attn_weights)
1202
+
1203
+ attn_weights += attn_bias
1204
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
1205
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
1206
+ return torch.matmul(attn_weights, v)
1207
+
1208
+ def forward(
1209
+ self,
1210
+ x: torch.Tensor,
1211
+ qx: torch.Tensor,
1212
+ sx: torch.Tensor,
1213
+ attention_bias: torch.Tensor | None = None,
1214
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
1215
+ use_cache: bool = False,
1216
+ max_doc_len: int | None = None,
1217
+ cu_doc_lens: torch.Tensor | None = None,
1218
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
1219
+ # Get query, key, value projections.
1220
+ # shape:
1221
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
1222
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
1223
+ # k, v: (batch_size, seq_len, d_model // n_heads)
1224
+ x_normed = self.attn_norm(x)
1225
+ q = self.q_proj(x_normed)
1226
+ k = self.k_proj(x_normed)
1227
+ v = self.v_proj(x_normed)
1228
+
1229
+ if self.config.clip_qkv is not None:
1230
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
1231
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
1232
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
1233
+
1234
+ # Get attention scores.
1235
+ att, cache = self.attention(
1236
+ q,
1237
+ k,
1238
+ v,
1239
+ attention_bias,
1240
+ layer_past=layer_past,
1241
+ use_cache=use_cache,
1242
+ max_doc_len=max_doc_len,
1243
+ cu_doc_lens=cu_doc_lens,
1244
+ )
1245
+
1246
+ att = self.attn_out(att) # NOTE: we move the attn_out outside the self.attention module
1247
+
1248
+ # Add attention scores.
1249
+ # shape: (B, T, C)
1250
+ x = x + self.dropout(att)
1251
+
1252
+ # Add feed-forward projection.
1253
+ # shape: (batch_size, seq_len, d_model)
1254
+ og_x = x
1255
+ if self._activation_checkpoint_fn is not None:
1256
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
1257
+ else:
1258
+ x = self.ff_norm(x)
1259
+ x = self.ff_proj(x)
1260
+ if self._activation_checkpoint_fn is not None:
1261
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
1262
+ else:
1263
+ x = self.act(x)
1264
+ x = self.ff_out(x)
1265
+ x = self.dropout(x)
1266
+ x = og_x + x
1267
+
1268
+ return x, cache
1269
+
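+ # A minimal sketch (hypothetical shapes) of the manual attention used by
+ # CoatOLMoLlamaBlock._scaled_dot_product_attention: scores = q @ k^T / sqrt(d), an additive
+ # causal bias of -inf above the diagonal, softmax, then a weighted sum over v.
+ def _sketch_manual_attention():
+     import math
+
+     import torch
+
+     B, H, T, D = 2, 4, 8, 16
+     q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
+     scores = q @ k.transpose(-2, -1) / math.sqrt(D)
+     causal_bias = torch.full((T, T), float("-inf")).triu(diagonal=1)
+     att = torch.softmax(scores + causal_bias, dim=-1)
+     return (att @ v).shape  # (B, H, T, D)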
1270
+
1271
+ class CoatOLMoBlockGroup(nn.ModuleList):
1272
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Iterable[nn.Module] | None = None):
1273
+ super().__init__(modules)
1274
+ self.config = config
1275
+ self.layer_offset = layer_offset
1276
+ self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None
1277
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
1278
+
1279
+ def forward(
1280
+ self,
1281
+ x: torch.Tensor,
1282
+ attention_bias: torch.FloatTensor | None = None,
1283
+ layers_past: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
1284
+ use_cache: bool = False,
1285
+ max_doc_len: int | None = None,
1286
+ cu_doc_lens: torch.Tensor | None = None,
1287
+ ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]] | None]:
1288
+ attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None
1289
+ for block_idx, block in enumerate(self):
1290
+ layer_past = None if layers_past is None else layers_past[block_idx]
1291
+ block_idx += self.layer_offset
1292
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
1293
+ # shape: (batch_size, seq_len, d_model)
1294
+ x, cache = self._activation_checkpoint_fn( # type: ignore
1295
+ block,
1296
+ x,
1297
+ attention_bias=attention_bias,
1298
+ layer_past=layer_past,
1299
+ use_cache=use_cache,
1300
+ max_doc_len=max_doc_len,
1301
+ cu_doc_lens=cu_doc_lens,
1302
+ )
1303
+ else:
1304
+ # shape: (batch_size, seq_len, d_model)
1305
+ x, cache = block(
1306
+ x,
1307
+ attention_bias=attention_bias,
1308
+ layer_past=layer_past,
1309
+ use_cache=use_cache,
1310
+ max_doc_len=max_doc_len,
1311
+ cu_doc_lens=cu_doc_lens,
1312
+ )
1313
+ if attn_key_values is not None:
1314
+ assert cache is not None
1315
+ attn_key_values.append(cache)
1316
+ return x, attn_key_values
1317
+
1318
+ def reset_parameters(self):
1319
+ for block in self:
1320
+ block.reset_parameters()
1321
+
1322
+ def set_activation_checkpointing(
1323
+ self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None
1324
+ ):
1325
+ self.activation_checkpointing_strategy = strategy
1326
+ for block in self:
1327
+ block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
1328
+
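+ # A minimal sketch (illustrative values; `_sketch_mask_to_additive_bias` is hypothetical) of
+ # the padding-mask handling in CoatOLMo.forward below: a 1/0 attention_mask becomes an
+ # additive bias of 0.0 (attend) or a very large negative number (ignore) added to the scores.
+ def _sketch_mask_to_additive_bias():
+     import torch
+
+     attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])   # (batch_size, seq_len)
+     bias = attention_mask.view(1, -1)[:, None, None, :]     # (batch_size, 1, 1, seq_len)
+     bias = (1.0 - bias) * torch.finfo(torch.float).min
+     return bias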
1329
+
1330
+ class CoatOLMo(nn.Module):
1331
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, init_params: bool = True):
1332
+ super().__init__()
1333
+ self.config = config
1334
+ self.qargs = qargs
1335
+ self.__cache = BufferCache()
1336
+
1337
+ # Validate config.
1338
+ if self.config.alibi and self.config.flash_attention:
1339
+ raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
1340
+
1341
+ if self.config.alibi and self.config.rope:
1342
+ raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
1343
+
1344
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
1345
+ if self.config.embedding_size < self.config.vocab_size:
1346
+ raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
1347
+ elif self.config.embedding_size % 128 != 0:
1348
+ import warnings
1349
+
1350
+ warnings.warn(
1351
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1352
+ )
1353
+
1354
+ self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None
1355
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
1356
+
1357
+ if not (
1358
+ 0 < self.config.block_group_size <= self.config.n_layers
1359
+ and self.config.n_layers % self.config.block_group_size == 0
1360
+ ):
1361
+ raise OLMoConfigurationError("n layers must be divisible by block group size")
1362
+
1363
+ torch.backends.cuda.enable_flash_sdp(True)
1364
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
1365
+
1366
+ self.transformer = nn.ModuleDict(
1367
+ dict(
1368
+ wte=nn.Embedding(config.embedding_size or config.vocab_size, config.d_model, device=config.init_device),
1369
+ emb_drop=Dropout(config.embedding_dropout),
1370
+ ln_f=LayerNorm.build(config),
1371
+ )
1372
+ )
1373
+
1374
+ blocks = [CoatOLMoBlock.build(i, config, qargs, self.__cache) for i in range(config.n_layers)]
1375
+ if self.config.block_group_size > 1:
1376
+ block_groups = [
1377
+ CoatOLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
1378
+ for i in range(0, config.n_layers, config.block_group_size)
1379
+ ]
1380
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
1381
+ else:
1382
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
1383
+
1384
+ if not (self.config.alibi or self.config.rope):
1385
+ self.transformer.update(
1386
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
1387
+ )
1388
+ if not config.weight_tying:
1389
+ self.transformer.update(
1390
+ {
1391
+ "ff_out": nn.Linear(
1392
+ config.d_model,
1393
+ config.embedding_size or config.vocab_size,
1394
+ bias=config.include_bias,
1395
+ device=config.init_device,
1396
+ )
1397
+ }
1398
+ )
1399
+ if config.embedding_layer_norm:
1400
+ self.transformer.update({"emb_norm": LayerNorm.build(config)})
1401
+
1402
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1403
+ if init_params and self.config.init_device != "meta":
1404
+ self.reset_parameters()
1405
+ self.__num_fwd_flops: int | None = None
1406
+ self.__num_bck_flops: int | None = None
1407
+
1408
+ # Warm up cache.
1409
+ if self.config.alibi:
1410
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
1411
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1412
+
1413
+ # Quantize
1414
+ self.quantize_input_before_block = Coat_quantize_bgn(qargs)
1415
+ self.quantize_output_after_block = Coat_quantize_end(qargs)
1416
+
1417
+ set_activation_checkpointing = OLMo.set_activation_checkpointing
1418
+ device = OLMo.device
1419
+ reset_parameters = OLMo.reset_parameters
1420
+ get_alibi_attention_bias = OLMo.get_alibi_attention_bias
1421
+
1422
+ def forward(
1423
+ self,
1424
+ input_ids: torch.LongTensor,
1425
+ input_embeddings: torch.FloatTensor | None = None,
1426
+ attention_mask: torch.Tensor | None = None,
1427
+ attention_bias: torch.Tensor | None = None,
1428
+ past_key_values: Sequence[tuple[torch.Tensor, torch.Tensor]] | None = None,
1429
+ use_cache: bool = False,
1430
+ last_logits_only: bool = False,
1431
+ output_hidden_states: bool | None = None,
1432
+ doc_lens: torch.Tensor | None = None,
1433
+ max_doc_lens: Sequence[int] | None = None,
1434
+ ) -> OLMoOutput:
1435
+ """
1436
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1437
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
1438
+ embeddings. When provided, it is treated as the output of the input embedding layer.
1439
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
1440
+ which input IDs are masked. A `1` value in the mask means that
1441
+ the corresponding input ID should *not* be ignored. A `0` means
1442
+ that the corresponding input ID is masked.
1443
+
1444
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
1445
+ library.
1446
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
1447
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
1448
+ to introduce causal or other biases.
1449
+
1450
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
1451
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
1452
+ element in the sequence.
1453
+
1454
+ If the tensor is a float tensor, it will just be added to the attention
1455
+ scores before the softmax.
1456
+
1457
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
1458
+ :param past_key_values: Pre-computed keys and values for each attention block.
1459
+ Can be used to speed up sequential decoding. The `input_ids` which have
1460
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
1461
+ :param use_cache: If `True`, return key and value tensors for each block.
1462
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
1463
+ This can speed up decoding when you only care about the next token.
1464
+ :param doc_lens: Document lengths to use in attention for intra-document masking.
1465
+ Shape `(batch_size, max_docs)`.
1466
+ :param max_doc_lens: Maximum document length for each instance in the batch.
1467
+ """
1468
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
1469
+
1470
+ if past_key_values:
1471
+ assert len(past_key_values) == self.config.n_layers
1472
+
1473
+ batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
1474
+ if past_key_values is None:
1475
+ past_length = 0
1476
+ else:
1477
+ past_length = past_key_values[0][0].size(-2)
1478
+
1479
+ max_doc_len: int | None = None
1480
+ cu_doc_lens: torch.Tensor | None = None
1481
+ if doc_lens is not None and max_doc_lens is not None:
1482
+ max_doc_len = max(max_doc_lens)
1483
+ cu_doc_lens = get_cumulative_document_lengths(doc_lens)
1484
+
1485
+ # Get embeddings of input.
1486
+ # shape: (batch_size, seq_len, d_model)
1487
+ x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
1488
+
1489
+ # Apply embedding layer norm.
1490
+ if self.config.embedding_layer_norm:
1491
+ x = self.transformer.emb_norm(x)
1492
+
1493
+ if not (self.config.alibi or self.config.rope):
1494
+ # Get positional embeddings.
1495
+ # shape: (1, seq_len)
1496
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
1497
+ # shape: (1, seq_len, d_model)
1498
+ pos_emb = self.transformer.wpe(pos) # type: ignore
1499
+ x = pos_emb + x
1500
+
1501
+ # Apply dropout.
1502
+ # shape: (batch_size, seq_len, d_model)
1503
+ x = self.transformer.emb_drop(x) # type: ignore
1504
+
1505
+ # Transform the attention mask into what the blocks expect.
1506
+ if attention_mask is not None:
1507
+ # shape: (batch_size, 1, 1, seq_len)
1508
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1509
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
1510
+
1511
+ # Merge attention mask with attention bias.
1512
+ if (
1513
+ attention_bias is not None
1514
+ or attention_mask is not None
1515
+ or self.config.alibi
1516
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
1517
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
1518
+ # scores correctly.
1519
+ or past_key_values is not None
1520
+ ):
1521
+ if attention_bias is None and self.config.alibi:
1522
+ attention_bias = get_causal_attention_bias(
1523
+ self.__cache, past_length + seq_len, x.device
1524
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
1525
+ elif attention_bias is None:
1526
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
1527
+ elif attention_bias.dtype in (torch.int8, torch.bool):
1528
+ attention_bias = attention_bias.to(dtype=torch.float)
1529
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
1530
+
1531
+ # Transform to the right shape and data type.
1532
+ mask_len = seq_len
1533
+ if attention_mask is not None:
1534
+ mask_len = attention_mask.shape[-1]
1535
+ elif past_key_values is not None:
1536
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
1537
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
1538
+
1539
+ # Add in the masking bias.
1540
+ if attention_mask is not None:
1541
+ attention_bias = attention_bias + attention_mask
1542
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
1543
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
1544
+ # it can produce NaNs.
1545
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
1546
+
1547
+ attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None
1548
+
1549
+ # decoder layers
1550
+ all_hidden_states = []
1551
+
1552
+ # Prepare the (quantized) input for the COAT decoder layers
1553
+ x, qx, sx = self.quantize_input_before_block(x)
1554
+
1555
+ # Apply blocks one-by-one.
1556
+ if self.config.block_group_size == 1:
1557
+ for block_idx, block in enumerate(self.transformer.blocks):
1558
+ if output_hidden_states:
1559
+ # add hidden states
1560
+ all_hidden_states.append(x)
1561
+
1562
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
1563
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
1564
+ # shape: (batch_size, seq_len, d_model)
1565
+ x, qx, sx, cache = self._activation_checkpoint_fn(
1566
+ block,
1567
+ x,
1568
+ qx,
1569
+ sx,
1570
+ attention_bias=attention_bias,
1571
+ layer_past=layer_past,
1572
+ use_cache=use_cache,
1573
+ max_doc_len=max_doc_len,
1574
+ cu_doc_lens=cu_doc_lens,
1575
+ )
1576
+ else:
1577
+ # shape: (batch_size, seq_len, d_model)
1578
+ x, qx, sx, cache = block(
1579
+ x,
1580
+ qx,
1581
+ sx,
1582
+ attention_bias=attention_bias,
1583
+ layer_past=layer_past,
1584
+ use_cache=use_cache,
1585
+ max_doc_len=max_doc_len,
1586
+ cu_doc_lens=cu_doc_lens,
1587
+ )
1588
+
1589
+ if attn_key_values is not None:
1590
+ assert cache is not None
1591
+ attn_key_values.append(cache)
1592
+ else:
1593
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
1594
+ if output_hidden_states:
1595
+ # add hidden states
1596
+ all_hidden_states.append(x)
1597
+
1598
+ layers_past = (
1599
+ None
1600
+ if past_key_values is None
1601
+ else past_key_values[
1602
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1603
+ ]
1604
+ )
1605
+ x, cache = block_group(
1606
+ x,
1607
+ attention_bias=attention_bias,
1608
+ layers_past=layers_past,
1609
+ use_cache=use_cache,
1610
+ max_doc_len=max_doc_len,
1611
+ cu_doc_lens=cu_doc_lens,
1612
+ )
1613
+ if attn_key_values is not None:
1614
+ assert cache is not None
1615
+ attn_key_values.extend(cache)
1616
+
1617
+ # Combine the decoder-layer outputs (x, qx, sx) back into a single tensor
1618
+ x = self.quantize_output_after_block(x, qx, sx)
1619
+
1620
+ if last_logits_only:
1621
+ # shape: (batch_size, 1, d_model)
1622
+ x = x[:, -1, :].unsqueeze(1)
1623
+
1624
+ # Apply final layer norm.
1625
+ # shape: (batch_size, seq_len or 1, d_model)
1626
+ x = self.transformer.ln_f(x) # type: ignore
1627
+ if output_hidden_states:
1628
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
1629
+ all_hidden_states.append(x)
1630
+
1631
+ # Get logits.
1632
+ # shape: (batch_size, seq_len or 1, vocab_size)
1633
+ if self.config.weight_tying:
1634
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1635
+ else:
1636
+ logits = self.transformer.ff_out(x) # type: ignore
1637
+ if self.config.scale_logits:
1638
+ logits.mul_(1 / math.sqrt(self.config.d_model))
1639
+
1640
+ return OLMoOutput(
1641
+ logits=logits,
1642
+ attn_key_values=attn_key_values,
1643
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
1644
+ )
1645
+
1646
+ def get_fsdp_wrap_policy(self, wrap_strategy: FSDPWrapStrategy | None = None):
1647
+ if wrap_strategy is None:
1648
+ return None
1649
+
1650
+ # The 'recurse' mode for the wrap function does not behave like you'd expect.
1651
+ # Even if we return False, it may still recurse because PyTorch does what it wants,
1652
+ # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
1653
+ # but not other linear layers within a block.
1654
+ # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
1655
+ # return True in 'recurse' mode for simplicity.
1656
+ size_based_module_to_wrap = {self.transformer.wte}
1657
+ if hasattr(self.transformer, "ff_out"):
1658
+ size_based_module_to_wrap.add(self.transformer.ff_out)
1659
+
1660
+ if wrap_strategy == FSDPWrapStrategy.by_block:
1661
+
1662
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1663
+ del nonwrapped_numel
1664
+ wrap = isinstance(module, CoatOLMoBlock)
1665
+ if recurse:
1666
+ return True
1667
+ else:
1668
+ return wrap
1669
+
1670
+ return fsdp_wrap_fn
1671
+ elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
1672
+
1673
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1674
+ del nonwrapped_numel
1675
+ wrap = isinstance(module, (CoatOLMoBlock,)) or module in size_based_module_to_wrap
1676
+ if recurse:
1677
+ return True
1678
+ else:
1679
+ return wrap
1680
+
1681
+ return fsdp_wrap_fn
1682
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group:
1683
+ if self.config.block_group_size <= 1:
1684
+ raise OLMoConfigurationError(
1685
+ "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
1686
+ )
1687
+
1688
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1689
+ del nonwrapped_numel
1690
+ wrap = isinstance(module, CoatOLMoBlockGroup)
1691
+ if recurse:
1692
+ return True
1693
+ else:
1694
+ return wrap
1695
+
1696
+ return fsdp_wrap_fn
1697
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
1698
+ if self.config.block_group_size <= 1:
1699
+ raise OLMoConfigurationError(
1700
+ "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
1701
+ )
1702
+
1703
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1704
+ del nonwrapped_numel
1705
+ wrap = isinstance(module, (CoatOLMoBlockGroup,)) or module in size_based_module_to_wrap
1706
+ if recurse:
1707
+ return True
1708
+ else:
1709
+ return wrap
1710
+
1711
+ return fsdp_wrap_fn
1712
+ elif wrap_strategy == FSDPWrapStrategy.size_based:
1713
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
1714
+
1715
+ return size_based_auto_wrap_policy
1716
+ elif wrap_strategy in {
1717
+ FSDPWrapStrategy.one_in_two,
1718
+ FSDPWrapStrategy.one_in_three,
1719
+ FSDPWrapStrategy.one_in_four,
1720
+ FSDPWrapStrategy.one_in_five,
1721
+ }:
1722
+ c = {
1723
+ FSDPWrapStrategy.one_in_two: 2,
1724
+ FSDPWrapStrategy.one_in_three: 3,
1725
+ FSDPWrapStrategy.one_in_four: 4,
1726
+ FSDPWrapStrategy.one_in_five: 5,
1727
+ }[wrap_strategy]
1728
+
1729
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1730
+ del nonwrapped_numel
1731
+ wrap = isinstance(module, CoatOLMoBlock) and module.layer_id % c == 0
1732
+ if recurse:
1733
+ return True
1734
+ else:
1735
+ return wrap
1736
+
1737
+ return fsdp_wrap_fn
1738
+ else:
1739
+ raise NotImplementedError(wrap_strategy)
1740
+
1741
+ num_params = OLMo.num_params
1742
+
1743
+ @property
1744
+ def num_fwd_flops(self):
1745
+ if self.__num_fwd_flops:
1746
+ return self.__num_fwd_flops
1747
+
1748
+ # embedding table is just a lookup in the forward pass
1749
+ n_params = self.num_params(include_embedding=False)
1750
+ # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
1751
+ # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
1752
+ # this gets us FLOPs / token
1753
+ params_flops_per_token = 2 * n_params
1754
+ # each MAC is 2 FLOPs; attention adds two matmuls per layer, A = Q @ K^T and out = A @ V (hence the extra factor of 2)
1755
+ attn_flops_per_token = self.config.n_layers * 2 * 2 * (self.config.d_model * self.config.max_sequence_length)
1756
+ self.__num_fwd_flops = params_flops_per_token + attn_flops_per_token
1757
+ return self.__num_fwd_flops
1758
+
1759
+ @property
1760
+ def num_bck_flops(self):
1761
+ if self.__num_bck_flops:
1762
+ return self.__num_bck_flops
1763
+
1764
+ n_params = self.num_params()
1765
+ params_flops_per_token = 4 * n_params
1766
+ attn_flops_per_token = self.config.n_layers * 8 * (self.config.d_model * self.config.max_sequence_length)
1767
+ self.__num_bck_flops = params_flops_per_token + attn_flops_per_token
1768
+ return self.__num_bck_flops
1769
+
1770
+ generate = OLMo.generate
1771
+
1772
+ @classmethod
1773
+ def from_checkpoint(
1774
+ cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: CheckpointType | None = None
1775
+ ) -> CoatOLMo:
1776
+ """
1777
+ Load an OLMo model from a checkpoint.
1778
+ """
1779
+ from olmo.util import resource_path
1780
+
1781
+ # Guess checkpoint type.
1782
+ if checkpoint_type is None:
1783
+ try:
1784
+ if resource_path(checkpoint_dir, "model.pt").is_file():
1785
+ checkpoint_type = CheckpointType.unsharded
1786
+ else:
1787
+ checkpoint_type = CheckpointType.sharded
1788
+ except FileNotFoundError:
1789
+ checkpoint_type = CheckpointType.sharded
1790
+
1791
+ # Load config.
1792
+ config_path = resource_path(checkpoint_dir, "config.yaml")
1793
+ model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
1794
+
1795
+ if checkpoint_type == CheckpointType.unsharded:
1796
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
1797
+ model_config.init_device = "cpu"
1798
+ model = CoatOLMo(model_config)
1799
+
1800
+ # Load state dict directly to target device.
1801
+ state_dict_path = resource_path(checkpoint_dir, "model.pt")
1802
+ state_dict = torch.load(state_dict_path, map_location="cpu")
1803
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
1804
+ model = model.to(torch.device(device))
1805
+ else:
1806
+ train_config = TrainConfig.load(config_path)
1807
+ if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core:
1808
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state # type: ignore
1809
+
1810
+ model_config.init_device = device
1811
+ model = CoatOLMo(model_config)
1812
+ load_model_and_optim_state(checkpoint_dir, model)
1813
+ else:
1814
+ # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
1815
+ from olmo.checkpoint import load_model_state
1816
+
1817
+ # Initialize model on target device. In this case the state dict is loaded in-place
1818
+ # so it's not necessary to start on CPU if the target device is a GPU.
1819
+ model_config.init_device = device
1820
+ model = CoatOLMo(model_config)
1821
+
1822
+ # Load state dict in place.
1823
+ load_model_state(checkpoint_dir, model)
1824
+
1825
+ return model.eval()
1826
+
1827
+ def _make_state_dict_compatible(
1828
+ self, state_dict: dict[str, torch.Tensor]
1829
+ ) -> tuple[dict[str, torch.Tensor], dict[str, set[str]]]:
1830
+ """
1831
+ Handles some cases where the state dict is valid yet may need to be transformed in order to
1832
+ be loaded.
1833
+
1834
+ This modifies the state dict in-place and also returns it, along with a mapping of original key
1835
+ names to new key names in cases where the keys were simply renamed. That mapping can be used
1836
+ to make a corresponding optimizer state dict compatible as well.
1837
+ """
1838
+ import re
1839
+ from fnmatch import fnmatch
1840
+
1841
+ new_keys_to_og_keys: dict[str, str] = {}
1842
+
1843
+ # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
1844
+ # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
1845
+ # fine without the prefixes. This also simplifies the other steps below.
1846
+ for key in list(state_dict.keys()):
1847
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
1848
+ new_keys_to_og_keys[new_key] = key
1849
+
1850
+ # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
1851
+ if self.config.block_type == BlockType.sequential:
1852
+ for key in list(state_dict.keys()):
1853
+ if fnmatch(key, "transformer.*.norm.weight"):
1854
+ tensor = state_dict.pop(key)
1855
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
1856
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1857
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
1858
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1859
+ del new_keys_to_og_keys[key]
1860
+ elif fnmatch(key, "transformer.*.norm.bias"):
1861
+ tensor = state_dict.pop(key)
1862
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
1863
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1864
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
1865
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1866
+ del new_keys_to_og_keys[key]
1867
+
1868
+ # Real quantization changes where the linear layers live, so remap their weight keys to the new submodules
1869
+ if self.qargs.use_quantize_model == "coat_real":
1870
+ for key in list(state_dict.keys()):
1871
+ if fnmatch(key, "transformer.blocks.*.att_proj.weight") and "BeforeAttention" not in key:
1872
+ tensor = state_dict.pop(key)
1873
+ state_dict[(new_key := key.replace("att_proj.weight", "BeforeAttention.att_proj.weight"))] = tensor
1874
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1875
+ del new_keys_to_og_keys[key]
1876
+ elif fnmatch(key, "transformer.blocks.*.attn_out.weight") and "AfterAttention" not in key:
1877
+ tensor = state_dict.pop(key)
1878
+ state_dict[(new_key := key.replace("attn_out.weight", "AfterAttention.attn_out.weight"))] = tensor
1879
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1880
+ del new_keys_to_og_keys[key]
1881
+ elif fnmatch(key, "transformer.blocks.*.ff_proj.weight") and "MLPResidual" not in key:
1882
+ tensor = state_dict.pop(key)
1883
+ state_dict[(new_key := key.replace("ff_proj.weight", "MLPResidual.ff_proj.weight"))] = tensor
1884
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1885
+ del new_keys_to_og_keys[key]
1886
+ elif fnmatch(key, "transformer.blocks.*.ff_out.weight") and "MLPResidual" not in key:
1887
+ tensor = state_dict.pop(key)
1888
+ state_dict[(new_key := key.replace("ff_out.weight", "MLPResidual.ff_out.weight"))] = tensor
1889
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1890
+ del new_keys_to_og_keys[key]
1891
+
1892
+ # For loading a state dict that was saved with a different `block_group_size`.
1893
+ if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
1894
+ state_dict_block_group_size = len(
1895
+ [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
1896
+ )
1897
+ else:
1898
+ state_dict_block_group_size = 1
1899
+ if self.config.block_group_size != state_dict_block_group_size:
1900
+ log.info(
1901
+ f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
1902
+ f"group size {self.config.block_group_size}"
1903
+ )
1904
+ # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
1905
+ # and then (re-)group them into the right block sizes.
1906
+ if state_dict_block_group_size > 1:
1907
+ for key in list(state_dict.keys()):
1908
+ if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
1909
+ group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
1910
+ block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
1911
+ state_dict[
1912
+ (
1913
+ new_key := key.replace(
1914
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
1915
+ )
1916
+ )
1917
+ ] = state_dict.pop(key)
1918
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1919
+
1920
+ if self.config.block_group_size > 1:
1921
+ # Group the state dict blocks into the right block size.
1922
+ for key in list(state_dict.keys()):
1923
+ if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
1924
+ block_idx = int(m.group(1))
1925
+ group_idx, group_block_idx = (
1926
+ block_idx // self.config.block_group_size,
1927
+ block_idx % self.config.block_group_size,
1928
+ )
1929
+ state_dict[
1930
+ (
1931
+ new_key := key.replace(
1932
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
1933
+ )
1934
+ )
1935
+ ] = state_dict.pop(key)
1936
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1937
+
1938
+ og_keys_to_new: dict[str, set[str]] = defaultdict(set)
1939
+ for new_key, og_key in new_keys_to_og_keys.items():
1940
+ og_keys_to_new[og_key].add(new_key)
1941
+
1942
+ return state_dict, og_keys_to_new
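The flatten-then-regroup step above is easiest to follow on a toy key. The sketch below reproduces only the renaming arithmetic with plain regular expressions; it is a standalone illustration (the key name and group sizes are made up) and does not call into CoatOLMo or touch a real checkpoint.

import re

def regroup_key(key: str, old_group_size: int, new_group_size: int) -> str:
    # Step 1: flatten "block_groups.G.B." back to a global "blocks.IDX."
    m = re.match(r"transformer\.block_groups\.(\d+)\.(\d+)\.(.*)", key)
    if m is not None and old_group_size > 1:
        group_idx, group_block_idx, rest = int(m.group(1)), int(m.group(2)), m.group(3)
        key = f"transformer.blocks.{group_idx * old_group_size + group_block_idx}.{rest}"
    # Step 2: regroup "blocks.IDX." into the target block-group layout
    m = re.match(r"transformer\.blocks\.(\d+)\.(.*)", key)
    if m is not None and new_group_size > 1:
        block_idx, rest = int(m.group(1)), m.group(2)
        key = (f"transformer.block_groups.{block_idx // new_group_size}."
               f"{block_idx % new_group_size}.{rest}")
    return key

# A checkpoint saved without block groups, loaded into a model with block_group_size=4:
print(regroup_key("transformer.blocks.5.attn_out.weight", 1, 4))
# -> transformer.block_groups.1.1.attn_out.weight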
llava/model/coat/activation/real_quantization/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Activation
8
+ # Utils
9
+ from ._dequantize import fp8_dequantize
10
+ from ._division import fp8_division
11
+ from ._division_transpose import fp8_division_transpose
12
+ from ._quantize import fp8_quantize
13
+ from ._quantize_pertensor import fp8_quantize_pertensor
14
+ from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose
15
+ from ._transpose import fp8_transpose
16
+ from .add_bwd import fp8_add_Ifp_Ifp_Ofp_Opt
17
+ from .add_fwd import fp8_add_Ifp_Ifp_Ofp_Og16
18
+
19
+ # Normalization
20
+ from .func_layernorm_noparam import fp8_layernorm_noparam_backward, fp8_layernorm_noparam_forward
21
+ from .func_quantize import Coat_quantize_bgn, Coat_quantize_end
22
+ from .func_rmsnorm import fp8_rmsnorm_backward, fp8_rmsnorm_forward
23
+ from .gelu_bwd import fp8_gelu_backward
24
+ from .gelu_fwd import fp8_gelu_forward
25
+
26
+ # linear and add
27
+ from .linear import fp8_linear_backward, fp8_linear_forward
28
+ from .mul_bwd import fp8_mul_backward
29
+ from .mul_fwd import fp8_mul_forward
30
+ from .silu_bwd import fp8_silu_backward
31
+ from .silu_fwd import fp8_silu_forward
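A minimal import sketch for the kernels re-exported above, assuming the repository root is on PYTHONPATH and Triton is installed (the kernels themselves need a CUDA device to run):

from llava.model.coat.activation.real_quantization import (
    fp8_quantize,            # 1 x QB group quantization
    fp8_dequantize,          # inverse of fp8_quantize
    fp8_division,            # per-tensor cast given a scale
    fp8_division_transpose,  # per-tensor cast plus transposed copy
    fp8_quantize_pertensor,
    fp8_transpose,
)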
llava/model/coat/activation/real_quantization/_dequantize.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
31
+
32
+ """Quantize Operator"""
33
+ """Input uses 1 * 16 group quantization"""
34
+ """Output uses 1 * 16 group quantization"""
35
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
36
+
37
+
38
+ @triton.autotune(
39
+ configs=[] + get_configs_io_block(),
40
+ key=[
41
+ "N",
42
+ ],
43
+ )
44
+ @triton.heuristics(
45
+ {
46
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
47
+ }
48
+ )
49
+ @triton.jit
50
+ def _fp8_dequantize_kernel(
51
+ output_ptr, # output
52
+ input_ptr,
53
+ input_scale_ptr, # input
54
+ M,
55
+ N,
56
+ SN,
57
+ QB: tl.constexpr, # shape
58
+ input_stride_0,
59
+ input_stride_1, # input stride
60
+ s_input_stride_0,
61
+ s_input_stride_1, # scale of output stride
62
+ output_stride_0,
63
+ output_stride_1, # output stride
64
+ BLOCK_M: tl.constexpr,
65
+ BLOCK_N: tl.constexpr,
66
+ BLOCK_SN: tl.constexpr,
67
+ ): # CUDA block size
68
+
69
+ # Block PID
70
+ pid = tl.program_id(0)
71
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
72
+ pid_dim0 = pid // NUM_BLOCK_N
73
+ pid_dim1 = pid % NUM_BLOCK_N
74
+
75
+ # pointers
76
+ input_block_ptr = tl.make_block_ptr(
77
+ base=input_ptr,
78
+ shape=(M, N),
79
+ strides=(input_stride_0, input_stride_1),
80
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
81
+ block_shape=(BLOCK_M, BLOCK_N),
82
+ order=(1, 0),
83
+ )
84
+ # input ptr
85
+ scale_input_ptr = tl.make_block_ptr(
86
+ base=input_scale_ptr,
87
+ shape=(M, SN),
88
+ strides=(s_input_stride_0, s_input_stride_1),
89
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
90
+ block_shape=(BLOCK_M, BLOCK_SN),
91
+ order=(1, 0),
92
+ )
93
+
94
+ input = tl.load(input_block_ptr)
95
+ scale_input = tl.load(scale_input_ptr)
96
+
97
+ input = input.to(tl.float32)
98
+ scale_input = scale_input.to(tl.float32)
99
+
100
+ # Dequantize and gelu calculation
101
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
102
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
103
+ output = input * scale_input
104
+
105
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
106
+ output = output.to(output_ptr.dtype.element_ty)
107
+
108
+ # debug
109
+ # gelu_output = input
110
+ # scale_output = scale_input
111
+
112
+ # pointers
113
+ output_block_ptr = tl.make_block_ptr(
114
+ base=output_ptr,
115
+ shape=(M, N),
116
+ strides=(output_stride_0, output_stride_1),
117
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
118
+ block_shape=(BLOCK_M, BLOCK_N),
119
+ order=(1, 0),
120
+ )
121
+
122
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
123
+
124
+
125
+ def fp8_dequantize(x, s_x, QB):
126
+ # Change batched 3D input to 2D
127
+ batched = False
128
+ if len(x.shape) == 3:
129
+ batched = True
130
+ BS = x.shape[0]
131
+ x = x.reshape(-1, x.shape[-1])
132
+ s_x = s_x.reshape(-1, s_x.shape[-1])
133
+
134
+ # defining the input and output tensor
135
+ M, N = x.shape
136
+ SN = N // QB
137
+
138
+ y = torch.empty_like(x, dtype=torch.bfloat16)
139
+
140
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
141
+
142
+ _fp8_dequantize_kernel[grid](
143
+ y,
144
+ x,
145
+ s_x,
146
+ M,
147
+ N,
148
+ SN,
149
+ QB,
150
+ x.stride(0),
151
+ x.stride(1),
152
+ s_x.stride(0),
153
+ s_x.stride(1),
154
+ y.stride(0),
155
+ y.stride(1),
156
+ )
157
+
158
+ # Recover 2D to 3D
159
+ if batched:
160
+ y = y.reshape(BS, -1, y.shape[-1])
161
+
162
+ return y
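A round-trip usage sketch for fp8_dequantize, pairing it with fp8_quantize from this package. It assumes Triton plus a CUDA device, and that torch.float8_e4m3fn is an accepted dtype key (the string spellings accepted by convert_str_to_fp8 live in common.py, which is not part of this diff); the last dimension must be divisible by the group size QB.

import torch
from llava.model.coat.activation.real_quantization import fp8_quantize, fp8_dequantize

QB = 16                                                  # 1 x 16 group size
x = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16)

qx, sx = fp8_quantize(x, QB, torch.float8_e4m3fn)        # FP8 values + per-group scales
x_hat = fp8_dequantize(qx, sx, QB)                       # BF16 reconstruction
print((x - x_hat).abs().max())                           # small group-quantization error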
llava/model/coat/activation/real_quantization/_division.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
31
+
32
+ """Quantize and Transpose Operator"""
33
+ """Input uses 1 * 16 group quantization"""
34
+ """Output uses 1 * 16 group quantization"""
35
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
36
+
37
+
38
+ @triton.autotune(
39
+ configs=[] + get_configs_io_block(),
40
+ key=[
41
+ "N",
42
+ ],
43
+ )
44
+ @triton.heuristics(
45
+ {
46
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
47
+ }
48
+ )
49
+ @triton.jit
50
+ def _fp8_division_kernel(
51
+ output_ptr, # output
52
+ input_ptr,
53
+ input_scale_ptr, # input
54
+ noise_ptr, # noise for stochastic
55
+ M,
56
+ N,
57
+ SN,
58
+ QB: tl.constexpr,
59
+ fp8_max,
60
+ e_bit: tl.constexpr,
61
+ m_bit: tl.constexpr, # shape
62
+ input_stride_0,
63
+ input_stride_1, # input stride
64
+ output_stride_0,
65
+ output_stride_1, # output stride
66
+ SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor
67
+ STOCHASTIC: tl.constexpr,
68
+ BLOCK_M: tl.constexpr,
69
+ BLOCK_N: tl.constexpr,
70
+ BLOCK_SN: tl.constexpr,
71
+ ): # CUDA block size
72
+
73
+ # Block PID
74
+ pid = tl.program_id(0)
75
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
76
+ pid_dim0 = pid // NUM_BLOCK_N
77
+ pid_dim1 = pid % NUM_BLOCK_N
78
+
79
+ # pointers
80
+ input_block_ptr = tl.make_block_ptr(
81
+ base=input_ptr,
82
+ shape=(M, N),
83
+ strides=(input_stride_0, input_stride_1),
84
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
85
+ block_shape=(BLOCK_M, BLOCK_N),
86
+ order=(1, 0),
87
+ )
88
+
89
+ input = tl.load(input_block_ptr)
90
+ input = input.to(tl.float32)
91
+ scale_output = tl.load(input_scale_ptr)
92
+ scale_output = scale_output.to(tl.float32)
93
+
94
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
95
+
96
+ # Quantize Scale calculation
97
+ # Quantize
98
+ output = tl.fdiv(output, scale_output)
99
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
100
+
101
+ if STOCHASTIC:
102
+ # noise_block_ptr = tl.make_block_ptr(
103
+ # base=noise_ptr,
104
+ # shape=(M, N),
105
+ # strides=(input_stride_0, input_stride_1),
106
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
107
+ # block_shape=(BLOCK_M, BLOCK_N),
108
+ # order=(1, 0)
109
+ # )
110
+ # noise = tl.load(noise_block_ptr)
111
+
112
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
113
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
114
+ noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
115
+ noise = tl.rand(0, noise_offset)
116
+
117
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
118
+
119
+ output = output.to(output_ptr.type.element_ty)
120
+
121
+ # pointers
122
+ output_block_ptr = tl.make_block_ptr(
123
+ base=output_ptr,
124
+ shape=(M, N),
125
+ strides=(output_stride_0, output_stride_1),
126
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
127
+ block_shape=(BLOCK_M, BLOCK_N),
128
+ order=(1, 0),
129
+ )
130
+
131
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
132
+
133
+
134
+ @triton.jit
135
+ def _stochastic_rounding(output, noise, e_bit: tl.constexpr, m_bit: tl.constexpr):
136
+ subnormal_min = tl.exp2(2 - tl.exp2(e_bit - 1) - m_bit)
137
+ # subnormal_should_be = tl.exp2(2 - tl.exp2(e_bit) - 1)
138
+
139
+ output_int32 = tl.cast(output, tl.int32, bitcast=True)
140
+ output_int32 = output_int32 & 0x7F800000
141
+ output_float32 = tl.cast(output_int32, tl.float32, bitcast=True)
142
+ output_exp = tl.maximum(output_float32, subnormal_min)
143
+
144
+ noise_rescale = tl.exp2(m_bit) + (output_exp == subnormal_min) * (
145
+ 1 - tl.exp2(m_bit)
146
+ ) # 2^m_bit for normal, 1 for subnormal
147
+
148
+ noise = output_exp * noise / noise_rescale
149
+ sign = 1 - 2 * libdevice.signbit(output)
150
+ output = tl.abs(output) + noise
151
+
152
+ minmax_ratio = 2 + (output_exp == subnormal_min) * (tl.exp2(m_bit) - 2) # 2 for normal, and 2^M for subnormal
153
+ output = sign * tl.clamp(output, min=output_exp, max=minmax_ratio * output_exp)
154
+
155
+ return output
156
+
157
+
158
+ def fp8_division(x, QB, fp8type, s_y=None, stochastic=False):
159
+ # Change batched 3D input to 2D
160
+ batched = False
161
+ if len(x.shape) == 3:
162
+ batched = True
163
+ BS = x.shape[0]
164
+ x = x.reshape(-1, x.shape[-1])
165
+
166
+ if stochastic:
167
+ # noise = torch.zeros_like(x, dtype=torch.float32).uniform_(-0.5, 0.5)
168
+ noise = None
169
+ else:
170
+ noise = None
171
+
172
+ # defining the input and output tensor
173
+ M, N = x.shape
174
+ SN = N // QB
175
+
176
+ if isinstance(fp8type, str):
177
+ fp8type = convert_str_to_fp8[fp8type]
178
+
179
+ y = torch.empty_like(x, dtype=fp8type)
180
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
181
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
182
+
183
+ if s_y is None:
184
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
185
+
186
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
187
+
188
+ _fp8_division_kernel[grid](
189
+ y,
190
+ x,
191
+ s_y,
192
+ noise,
193
+ M,
194
+ N,
195
+ SN,
196
+ QB,
197
+ fp8MaxValue,
198
+ e_bit,
199
+ m_bit,
200
+ x.stride(0),
201
+ x.stride(1),
202
+ y.stride(0),
203
+ y.stride(1),
204
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
205
+ STOCHASTIC=stochastic,
206
+ )
207
+
208
+ # Recover 2D to 3D
209
+ if batched:
210
+ y = y.reshape(BS, -1, y.shape[-1])
211
+
212
+ return y, s_y
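A usage sketch for fp8_division: it casts a BF16 tensor to FP8 with a single per-tensor scale (derived from the tensor max when s_y is None) and can optionally apply the stochastic rounding helper above. Passing torch.float8_e4m3fn as the dtype argument is an assumption; the exact keys of FP8_MAX_VALUE are defined in common.py, which is not shown in this diff.

import torch
from llava.model.coat.activation.real_quantization import fp8_division

QB = 16
x = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16)

y, s_y = fp8_division(x, QB, torch.float8_e4m3fn)                              # deterministic rounding
y_sr, _ = fp8_division(x, QB, torch.float8_e4m3fn, s_y=s_y, stochastic=True)   # stochastic rounding, same scale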
llava/model/coat/activation/real_quantization/_division_transpose.py ADDED
@@ -0,0 +1,215 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from ._division import _stochastic_rounding
31
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
32
+
33
+ """Division and Transpose Operator"""
34
+ """Input uses full-precision/BF16"""
35
+ """Output uses per tensor quantization"""
36
+ """Output_t uses per tensor quantization and is transposed, but is flattened to 2D"""
37
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
38
+
39
+
40
+ @triton.autotune(
41
+ configs=[] + get_configs_io_block(), # triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)
42
+ # configs=[triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)], #
43
+ key=[
44
+ "N",
45
+ ],
46
+ )
47
+ @triton.heuristics(
48
+ {
49
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
50
+ }
51
+ )
52
+ @triton.jit
53
+ def _fp8_division_transpose_kernel(
54
+ output_ptr,
55
+ output_t_ptr, # output
56
+ input_ptr,
57
+ input_scale_ptr, # input
58
+ noise_ptr, # noise for stochastic
59
+ M,
60
+ N,
61
+ SN,
62
+ QB: tl.constexpr,
63
+ fp8_max,
64
+ e_bit,
65
+ m_bit, # shape
66
+ input_stride_0,
67
+ input_stride_1, # input stride
68
+ output_stride_0,
69
+ output_stride_1, # output stride
70
+ output_t_stride_0,
71
+ output_t_stride_1, # output stride
72
+ SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor
73
+ STOCHASTIC: tl.constexpr,
74
+ ONLY_TRANSPOSED: tl.constexpr,
75
+ BLOCK_M: tl.constexpr,
76
+ BLOCK_N: tl.constexpr,
77
+ BLOCK_SN: tl.constexpr,
78
+ ): # CUDA block size
79
+
80
+ # Block PID
81
+ pid = tl.program_id(0)
82
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
83
+ pid_dim0 = pid // NUM_BLOCK_N
84
+ pid_dim1 = pid % NUM_BLOCK_N
85
+
86
+ # pointers
87
+ input_block_ptr = tl.make_block_ptr(
88
+ base=input_ptr,
89
+ shape=(M, N),
90
+ strides=(input_stride_0, input_stride_1),
91
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
92
+ block_shape=(BLOCK_M, BLOCK_N),
93
+ order=(1, 0),
94
+ )
95
+
96
+ input = tl.load(input_block_ptr)
97
+ input = input.to(tl.float32)
98
+ scale_output = tl.load(input_scale_ptr)
99
+ scale_output = scale_output.to(tl.float32)
100
+
101
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
102
+
103
+ # Quantize Scale calculation
104
+ # Quantize
105
+ output = tl.fdiv(output, scale_output)
106
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
107
+
108
+ if STOCHASTIC:
109
+ # noise_block_ptr = tl.make_block_ptr(
110
+ # base=noise_ptr,
111
+ # shape=(M, N),
112
+ # strides=(input_stride_0, input_stride_1),
113
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
114
+ # block_shape=(BLOCK_M, BLOCK_N),
115
+ # order=(1, 0)
116
+ # )
117
+ # noise = tl.load(noise_block_ptr)
118
+
119
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
120
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
121
+ noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
122
+ noise = tl.rand(0, noise_offset)
123
+
124
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
125
+
126
+ output = output.to(output_ptr.type.element_ty)
127
+ # tl.device_print("3: ", output)
128
+ output_t = tl.trans(output)
129
+
130
+ # pointers
131
+ output_block_ptr = tl.make_block_ptr(
132
+ base=output_ptr,
133
+ shape=(M, N),
134
+ strides=(output_stride_0, output_stride_1),
135
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
136
+ block_shape=(BLOCK_M, BLOCK_N),
137
+ order=(1, 0),
138
+ )
139
+ output_t_block_ptr = tl.make_block_ptr(
140
+ base=output_t_ptr,
141
+ shape=(N, M),
142
+ strides=(output_t_stride_0, output_t_stride_1),
143
+ offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
144
+ block_shape=(BLOCK_N, BLOCK_M),
145
+ order=(1, 0),
146
+ )
147
+ if not ONLY_TRANSPOSED:
148
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
149
+ tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1))
150
+
151
+
152
+ def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False, only_transposed=False):
153
+ # Change batched 3D input to 2D
154
+ batched = False
155
+ if len(x.shape) == 3:
156
+ batched = True
157
+ BS = x.shape[0]
158
+ x = x.reshape(-1, x.shape[-1])
159
+
160
+ if stochastic:
161
+ # noise = torch.empty_like(x, dtype=torch.float32).uniform_(-0.5, 0.5)
162
+ noise = None
163
+ else:
164
+ noise = None
165
+
166
+ # defining the input and output tensor
167
+ M, N = x.shape
168
+ SN = N // QB
169
+
170
+ if isinstance(fp8type, str):
171
+ fp8type = convert_str_to_fp8[fp8type]
172
+
173
+ y = torch.empty_like(x, dtype=fp8type)
174
+ y_t = torch.empty((N, M), dtype=fp8type, device=x.device)
175
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
176
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
177
+
178
+ if s_y is None:
179
+ # print("Warning: do not specify s_y in fp8_division_transpose")
180
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
181
+
182
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
183
+
184
+ _fp8_division_transpose_kernel[grid](
185
+ y,
186
+ y_t,
187
+ x,
188
+ s_y,
189
+ noise,
190
+ M,
191
+ N,
192
+ SN,
193
+ QB,
194
+ fp8MaxValue,
195
+ e_bit,
196
+ m_bit,
197
+ x.stride(0),
198
+ x.stride(1),
199
+ y.stride(0),
200
+ y.stride(1),
201
+ y_t.stride(0),
202
+ y_t.stride(1),
203
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
204
+ STOCHASTIC=stochastic,
205
+ ONLY_TRANSPOSED=only_transposed,
206
+ )
207
+
208
+ if not only_transposed:
209
+ # Recover 2D to 3D
210
+ if batched:
211
+ y = y.reshape(BS, -1, y.shape[-1])
212
+
213
+ return y, s_y, y_t # y_t is expected to be 2D tensor
214
+ else:
215
+ return y_t, s_y
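A usage sketch for fp8_division_transpose, which produces the FP8 cast and its transposed copy in one pass (the transposed copy is what the backward GEMM consumes). Same assumptions as above: Triton, a CUDA device, and torch.float8_e4m3fn as the dtype key.

import torch
from llava.model.coat.activation.real_quantization import fp8_division_transpose

QB = 16
x = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16)

y, s_y, y_t = fp8_division_transpose(x, QB, torch.float8_e4m3fn)
assert y.shape == (8, 128) and y_t.shape == (128, 8)

# only_transposed=True skips the non-transposed output and returns (y_t, s_y).
y_t_only, s_y2 = fp8_division_transpose(x, QB, torch.float8_e4m3fn, only_transposed=True)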
llava/model/coat/activation/real_quantization/_memory_io.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ CONST_BLOCK = 32
31
+
32
+ # The kernel with 1 load operation and 4 store operations
33
+ def get_configs_io_block():
34
+ configs = []
35
+ for nstages in [3, 4, 5, 6]:
36
+ for block_m in [32, 64, 128]:
37
+ for block_n in [32, 64, 128]:
38
+ for nwarps in [4, 8, 16, 32]:
39
+ configs.append(
40
+ triton.Config(
41
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
42
+ num_stages=nstages,
43
+ num_warps=nwarps,
44
+ )
45
+ )
46
+ return configs
47
+
48
+
49
+ @triton.autotune(
50
+ configs=[] + get_configs_io_block(),
51
+ key=[
52
+ "N",
53
+ ],
54
+ )
55
+ @triton.jit
56
+ def bench_memory_io_kernel_forward(
57
+ output_ptr,
58
+ input_ptr,
59
+ M,
60
+ N,
61
+ B: tl.constexpr,
62
+ input_stride_0,
63
+ input_stride_1,
64
+ output_stride_0,
65
+ output_stride_1,
66
+ BLOCK_M: tl.constexpr,
67
+ BLOCK_N: tl.constexpr,
68
+ ):
69
+
70
+ # Block PID
71
+ pid = tl.program_id(0)
72
+ NUM_BLOCK_M = tl.cdiv(M, BLOCK_M)
73
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
74
+ pid_dim0 = pid // NUM_BLOCK_N
75
+ pid_dim1 = pid % NUM_BLOCK_N
76
+
77
+ # pointers
78
+ input_block_ptr = tl.make_block_ptr(
79
+ base=input_ptr,
80
+ shape=(M, N),
81
+ strides=(input_stride_0, input_stride_1),
82
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
83
+ block_shape=(BLOCK_M, BLOCK_N),
84
+ order=(1, 0),
85
+ )
86
+
87
+ input = tl.load(input_block_ptr)
88
+ input = input.to(tl.float32)
89
+
90
+ output = input * 2
91
+
92
+ # pointers
93
+ output_block_ptr = tl.make_block_ptr(
94
+ base=output_ptr,
95
+ shape=(M, N),
96
+ strides=(output_stride_0, output_stride_1),
97
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
98
+ block_shape=(BLOCK_M, BLOCK_N),
99
+ order=(1, 0),
100
+ )
101
+
102
+ output = output.to(output_ptr.type.element_ty)
103
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
104
+
105
+
106
+ def bench_memory_io_forward(x, B):
107
+ # defining the input and output tensor
108
+ M, N = x.shape
109
+
110
+ y = torch.empty_like(x, dtype=x.dtype)
111
+
112
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
113
+
114
+ bench_memory_io_kernel_forward[grid](
115
+ y,
116
+ x,
117
+ M,
118
+ N,
119
+ B,
120
+ x.stride(0),
121
+ x.stride(1),
122
+ y.stride(0),
123
+ y.stride(1),
124
+ )
125
+ return y
126
+
127
+
128
+ configs = []
129
+ for SL in [8192]:
130
+ configs.append(
131
+ triton.testing.Benchmark( # test different matrix size influence
132
+ x_names=["CDIM"],
133
+ x_vals=[1024, 2048, 4096, 8192],
134
+ line_arg="dtype",
135
+ line_vals=[torch.int8, torch.float16, torch.float32],
136
+ line_names=["float8", "float16", "float32"],
137
+ styles=[("blue", "-"), ("green", "-"), ("red", "-")],
138
+ ylabel="time-cost",
139
+ plot_name=f"INT8GELU<BLSZ={CONST_BLOCK}><SL={SL}>",
140
+ args={"SL": SL, "B": CONST_BLOCK, "provider": "triton", "mode": "time-consuming"},
141
+ )
142
+ )
143
+
144
+
145
+ @triton.testing.perf_report(configs)
146
+ def bench_load_store(
147
+ SL, CDIM, B, provider, dtype, mode="forward"
148
+ ): # I only use triton as the provider, and mode when benchmarking
149
+ # create data
150
+ x = torch.randn(SL, CDIM, dtype=torch.float32).cuda()
151
+ x = x.to(dtype)
152
+
153
+ quantiles = [0.5, 0.2, 0.8]
154
+ # utility functions
155
+ if provider == "triton":
156
+
157
+ def y_fwd():
158
+ bench_memory_io_forward(x, B)
159
+
160
+ if provider == "torch":
161
+ torch_gelu = torch.nn.GELU()
162
+
163
+ def y_fwd():
164
+ return torch_gelu(x)
165
+
166
+ # forward pass
167
+ if mode == "time-consuming":
168
+ convert_func = lambda ms: ms
169
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
170
+ # backward pass
171
+ if mode == "gbps":
172
+ convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
173
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
174
+ return convert_func(ms), convert_func(max_ms), convert_func(min_ms)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ torch.manual_seed(0)
179
+ torch.set_printoptions(precision=8, linewidth=1600, sci_mode=False, edgeitems=3)
180
+ bench_load_store.run(print_data=True)
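For reference, the gbps mode above converts a measured kernel time into read-plus-write bandwidth. The numbers below are purely illustrative arithmetic (no GPU needed), and the 0.1 ms timing is made up.

numel, element_size, ms = 8192 * 4096, 2, 0.1        # FP16 tensor of shape (8192, 4096), 0.1 ms kernel
gbps = 2 * numel * element_size / ms * 1e-6          # bytes moved (load + store) per ms -> GB/s
print(round(gbps, 1))                                # ~1342.2 GB/s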
llava/model/coat/activation/real_quantization/_quantize.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
31
+
32
+ """Quantize Operator"""
33
+ """Input uses 1 * 16 group quantization"""
34
+ """Output uses 1 * 16 group quantization"""
35
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
36
+
37
+
38
+ @triton.autotune(
39
+ configs=[] + get_configs_io_block(),
40
+ key=[
41
+ "N",
42
+ ],
43
+ )
44
+ @triton.heuristics(
45
+ {
46
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
47
+ }
48
+ )
49
+ @triton.jit
50
+ def _fp8_quantize_kernel(
51
+ output_ptr,
52
+ output_scale_ptr, # output
53
+ input_ptr, # input
54
+ M,
55
+ N,
56
+ SN,
57
+ QB: tl.constexpr,
58
+ fp8_max, # shape
59
+ input_stride_0,
60
+ input_stride_1, # input stride
61
+ output_stride_0,
62
+ output_stride_1, # output stride
63
+ s_output_stride_0,
64
+ s_output_stride_1, # scale of output stride
65
+ SCALE_MIN_THRES: tl.constexpr,
66
+ BLOCK_M: tl.constexpr,
67
+ BLOCK_N: tl.constexpr,
68
+ BLOCK_SN: tl.constexpr,
69
+ ): # CUDA block size
70
+
71
+ # Block PID
72
+ pid = tl.program_id(0)
73
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
74
+ pid_dim0 = pid // NUM_BLOCK_N
75
+ pid_dim1 = pid % NUM_BLOCK_N
76
+
77
+ # pointers
78
+ input_block_ptr = tl.make_block_ptr(
79
+ base=input_ptr,
80
+ shape=(M, N),
81
+ strides=(input_stride_0, input_stride_1),
82
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
83
+ block_shape=(BLOCK_M, BLOCK_N),
84
+ order=(1, 0),
85
+ )
86
+
87
+ input = tl.load(input_block_ptr)
88
+ input = input.to(tl.float32)
89
+
90
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
91
+
92
+ # Quantize Scale calculation
93
+ abs_output = tl.abs(output)
94
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
95
+ scale_output = max_val / fp8_max
96
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
97
+
98
+ # Quantize
99
+ output = tl.fdiv(output, scale_output)
100
+
101
+ output = output.to(output_ptr.type.element_ty)
102
+
103
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
104
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
105
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
106
+
107
+ # debug
108
+ # gelu_output = input
109
+ # scale_output = scale_input
110
+
111
+ # pointers
112
+ output_block_ptr = tl.make_block_ptr(
113
+ base=output_ptr,
114
+ shape=(M, N),
115
+ strides=(output_stride_0, output_stride_1),
116
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
117
+ block_shape=(BLOCK_M, BLOCK_N),
118
+ order=(1, 0),
119
+ )
120
+ scale_output_ptr = tl.make_block_ptr(
121
+ base=output_scale_ptr,
122
+ shape=(M, SN),
123
+ strides=(s_output_stride_0, s_output_stride_1),
124
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
125
+ block_shape=(BLOCK_M, BLOCK_SN),
126
+ order=(1, 0),
127
+ )
128
+
129
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
130
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
131
+
132
+
133
+ def fp8_quantize(x, QB, fp8type):
134
+ # Change batched 3D input to 2D
135
+ batched = False
136
+ if len(x.shape) == 3:
137
+ batched = True
138
+ BS = x.shape[0]
139
+ x = x.reshape(-1, x.shape[-1])
140
+
141
+ # defining the input and output tensor
142
+ M, N = x.shape
143
+ SN = N // QB
144
+
145
+ if isinstance(fp8type, str):
146
+ fp8type = convert_str_to_fp8[fp8type]
147
+ y = torch.empty_like(x, dtype=fp8type)
148
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
149
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
150
+
151
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
152
+
153
+ _fp8_quantize_kernel[grid](
154
+ y,
155
+ s_y,
156
+ x,
157
+ M,
158
+ N,
159
+ SN,
160
+ QB,
161
+ fp8MaxValue,
162
+ x.stride(0),
163
+ x.stride(1),
164
+ y.stride(0),
165
+ y.stride(1),
166
+ s_y.stride(0),
167
+ s_y.stride(1),
168
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
169
+ )
170
+
171
+ # Recover 2D to 3D
172
+ if batched:
173
+ y = y.reshape(BS, -1, y.shape[-1])
174
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
175
+
176
+ return y, s_y
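A shape-oriented sketch of fp8_quantize on a batched 3D input: the per-group scale tensor has its trailing dimension divided by QB, and both outputs are reshaped back to 3D. As before, torch.float8_e4m3fn as the dtype key is an assumption (the string form accepted by convert_str_to_fp8 lives in common.py, not shown here).

import torch
from llava.model.coat.activation.real_quantization import fp8_quantize

QB = 16
x = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16)   # (batch, seq, hidden)

qx, sx = fp8_quantize(x, QB, torch.float8_e4m3fn)
print(qx.shape)   # torch.Size([2, 8, 64])  FP8 values
print(sx.shape)   # torch.Size([2, 8, 4])   per-group scales (64 // 16 = 4), bfloat16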
llava/model/coat/activation/real_quantization/_quantize_pertensor.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from ._division import fp8_division
31
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
32
+
33
+ """Per Tensor Quantize Operator"""
34
+ """Input uses full precision"""
35
+ """Output uses per tensor quantization"""
36
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
37
+
38
+
39
+ @triton.autotune(
40
+ configs=[] + get_configs_io_block(),
41
+ key=[
42
+ "N",
43
+ ],
44
+ )
45
+ @triton.heuristics(
46
+ {
47
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
48
+ }
49
+ )
50
+ @triton.jit
51
+ def _fp8_quantize_pertensor_kernel(
52
+ output_scale_ptr, # output
53
+ input_ptr, # input
54
+ M,
55
+ N,
56
+ SN,
57
+ QB: tl.constexpr,
58
+ fp8_max, # shape
59
+ input_stride_0,
60
+ input_stride_1, # input stride
61
+ s_output_stride_0,
62
+ s_output_stride_1, # scale of output stride
63
+ SCALE_MIN_THRES: tl.constexpr,
64
+ BLOCK_M: tl.constexpr,
65
+ BLOCK_N: tl.constexpr,
66
+ BLOCK_SN: tl.constexpr,
67
+ ): # CUDA block size
68
+
69
+ # Block PID
70
+ pid = tl.program_id(0)
71
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
72
+ pid_dim0 = pid // NUM_BLOCK_N
73
+ pid_dim1 = pid % NUM_BLOCK_N
74
+
75
+ # pointers
76
+ input_block_ptr = tl.make_block_ptr(
77
+ base=input_ptr,
78
+ shape=(M, N),
79
+ strides=(input_stride_0, input_stride_1),
80
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
81
+ block_shape=(BLOCK_M, BLOCK_N),
82
+ order=(1, 0),
83
+ )
84
+
85
+ input = tl.load(input_block_ptr)
86
+ input = input.to(tl.float32)
87
+
88
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
89
+
90
+ # Quantize Scale calculation
91
+ abs_output = tl.abs(output)
92
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
93
+ scale_output = max_val / fp8_max
94
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
95
+
96
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
97
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
98
+
99
+ scale_output_ptr = tl.make_block_ptr(
100
+ base=output_scale_ptr,
101
+ shape=(M, SN),
102
+ strides=(s_output_stride_0, s_output_stride_1),
103
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
104
+ block_shape=(BLOCK_M, BLOCK_SN),
105
+ order=(1, 0),
106
+ )
107
+
108
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
109
+
110
+
111
+ def fp8_quantize_pertensor(x, QB, fp8type, stochastic=False):
112
+ # Change batched 3D input to 2D
113
+ batched = False
114
+ if len(x.shape) == 3:
115
+ batched = True
116
+ BS = x.shape[0]
117
+ x = x.reshape(-1, x.shape[-1])
118
+
119
+ # defining the input and output tensor
120
+ M, N = x.shape
121
+ SN = N // QB
122
+
123
+ fp8type = convert_str_to_fp8[fp8type]
124
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
125
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
126
+
127
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
128
+
129
+ _fp8_quantize_pertensor_kernel[grid](
130
+ s_y,
131
+ x,
132
+ M,
133
+ N,
134
+ SN,
135
+ QB,
136
+ fp8MaxValue,
137
+ x.stride(0),
138
+ x.stride(1),
139
+ s_y.stride(0),
140
+ s_y.stride(1),
141
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
142
+ )
143
+
144
+ s_y_max = s_y.max()
145
+ y, s_y_max = fp8_division(x, QB, fp8type, s_y_max, stochastic=stochastic) # reuse the floating point output y1
146
+
147
+ # Recover 2D to 3D
148
+ if batched:
149
+ y = y.reshape(BS, -1, y.shape[-1])
150
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
151
+
152
+ return y, s_y_max, s_y
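A usage sketch for fp8_quantize_pertensor. Note that this function passes fp8type through convert_str_to_fp8 unconditionally, so it expects a string; "E4M3" is an assumed spelling, since convert_str_to_fp8 is defined in common.py and is not part of this diff.

import torch
from llava.model.coat.activation.real_quantization import fp8_quantize_pertensor

QB = 16
x = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16)

y, s_y_max, s_y = fp8_quantize_pertensor(x, QB, "E4M3")   # "E4M3" is an assumed key
# y:       FP8 tensor quantized with the single per-tensor scale s_y_max
# s_y_max: scalar scale (max over the per-group scales)
# s_y:     the intermediate per-group scales, shape (8, 128 // QB)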
llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from ._division_transpose import fp8_division_transpose
31
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
32
+
33
+ """Per Tensor Quantize and Transpose Operator"""
34
+ """Input uses floating point tensor"""
35
+ """Output uses per-tensor quantization, returns a non-transpose version and a transpose version"""
36
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
37
+
38
+
39
+ @triton.autotune(
40
+ configs=[] + get_configs_io_block(),
41
+ key=[
42
+ "N",
43
+ ],
44
+ )
45
+ @triton.heuristics(
46
+ {
47
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
48
+ }
49
+ )
50
+ @triton.jit
51
+ def _fp8_quantize_pertensor_transpose_kernel(
52
+ output_scale_ptr, # output
53
+ input_ptr, # input
54
+ M,
55
+ N,
56
+ SN,
57
+ QB: tl.constexpr,
58
+ fp8_max, # shape
59
+ input_stride_0,
60
+ input_stride_1, # input stride
61
+ s_output_stride_0,
62
+ s_output_stride_1, # scale of output stride
63
+ SCALE_MIN_THRES: tl.constexpr,
64
+ BLOCK_M: tl.constexpr,
65
+ BLOCK_N: tl.constexpr,
66
+ BLOCK_SN: tl.constexpr,
67
+ ): # CUDA block size
68
+
69
+ # Block PID
70
+ pid = tl.program_id(0)
71
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
72
+ pid_dim0 = pid // NUM_BLOCK_N
73
+ pid_dim1 = pid % NUM_BLOCK_N
74
+
75
+ # pointers
76
+ input_block_ptr = tl.make_block_ptr(
77
+ base=input_ptr,
78
+ shape=(M, N),
79
+ strides=(input_stride_0, input_stride_1),
80
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
81
+ block_shape=(BLOCK_M, BLOCK_N),
82
+ order=(1, 0),
83
+ )
84
+
85
+ input = tl.load(input_block_ptr)
86
+ input = input.to(tl.float32)
87
+
88
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
89
+
90
+ # Quantize Scale calculation
91
+ abs_output = tl.abs(output)
92
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
93
+ scale_output = max_val / fp8_max
94
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
95
+
96
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
97
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
98
+
99
+ scale_output_ptr = tl.make_block_ptr(
100
+ base=output_scale_ptr,
101
+ shape=(M, SN),
102
+ strides=(s_output_stride_0, s_output_stride_1),
103
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
104
+ block_shape=(BLOCK_M, BLOCK_SN),
105
+ order=(1, 0),
106
+ )
107
+
108
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
109
+
110
+
111
+ def fp8_quantize_pertensor_transpose(x, QB, fp8type, transpose_output_2d=False, stochastic=False):
112
+ # Change batched 3D input to 2D
113
+ batched = False
114
+ if len(x.shape) == 3:
115
+ batched = True
116
+ BS = x.shape[0]
117
+ x = x.reshape(-1, x.shape[-1])
118
+
119
+ # defining the input and output tensor
120
+ M, N = x.shape
121
+ SN = N // QB
122
+
123
+ fp8type = convert_str_to_fp8[fp8type]
124
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
125
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
126
+
127
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
128
+
129
+ _fp8_quantize_pertensor_transpose_kernel[grid](
130
+ s_y,
131
+ x,
132
+ M,
133
+ N,
134
+ SN,
135
+ QB,
136
+ fp8MaxValue,
137
+ x.stride(0),
138
+ x.stride(1),
139
+ s_y.stride(0),
140
+ s_y.stride(1),
141
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
142
+ )
143
+
144
+ s_y_max = s_y.max()
145
+ qy, s_y_max, qy_t = fp8_division_transpose(
146
+ x, QB, fp8type, s_y_max, stochastic=stochastic
147
+ ) # Stochastic Rounding happens here
148
+
149
+ # Recover 2D to 3D
150
+ if batched:
151
+ qy = qy.reshape(BS, -1, qy.shape[-1])
152
+ if not transpose_output_2d:
153
+ qy_t = qy_t.reshape(BS, -1, qy_t.shape[-1])
154
+
155
+ return qy, s_y_max, qy_t # y_t is expected to be 2D tensor
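A usage sketch for fp8_quantize_pertensor_transpose, which fuses the per-tensor quantization above with the transposed output needed by the backward pass. As before, the "E4M3" string is an assumed key of convert_str_to_fp8 (common.py is not part of this diff).

import torch
from llava.model.coat.activation.real_quantization import fp8_quantize_pertensor_transpose

QB = 16
x = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16)

qy, s_y_max, qy_t = fp8_quantize_pertensor_transpose(x, QB, "E4M3", transpose_output_2d=True)
assert qy.shape == (8, 128) and qy_t.shape == (128, 8)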
llava/model/coat/activation/real_quantization/_transpose.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from .common import get_configs_io_block
31
+
32
+ """Quantize Operator"""
33
+ """Input uses 1 * 16 group quantization"""
34
+ """Output uses 1 * 16 group quantization"""
35
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
36
+
37
+
38
+ @triton.autotune(
39
+ configs=[] + get_configs_io_block(),
40
+ key=[
41
+ "N",
42
+ ],
43
+ )
44
+ @triton.jit
45
+ def _fp8_transpose_kernel(
46
+ output_ptr, # output
47
+ input_ptr, # input
48
+ M,
49
+ N, # shape
50
+ input_stride_0,
51
+ input_stride_1, # input stride
52
+ output_stride_0,
53
+ output_stride_1, # output stride
54
+ BLOCK_M: tl.constexpr,
55
+ BLOCK_N: tl.constexpr,
56
+ ): # CUDA block size
57
+
58
+ # Block PID
59
+ pid = tl.program_id(0)
60
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
61
+ pid_dim0 = pid // NUM_BLOCK_N
62
+ pid_dim1 = pid % NUM_BLOCK_N
63
+
64
+ # pointers
65
+ input_block_ptr = tl.make_block_ptr(
66
+ base=input_ptr,
67
+ shape=(M, N),
68
+ strides=(input_stride_0, input_stride_1),
69
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
70
+ block_shape=(BLOCK_M, BLOCK_N),
71
+ order=(1, 0),
72
+ )
73
+
74
+ input = tl.load(input_block_ptr)
75
+
76
+ output = tl.trans(input)
77
+
78
+ # pointers
79
+ output_block_ptr = tl.make_block_ptr(
80
+ base=output_ptr,
81
+ shape=(N, M),
82
+ strides=(output_stride_0, output_stride_1),
83
+ offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
84
+ block_shape=(BLOCK_N, BLOCK_M),
85
+ order=(1, 0),
86
+ )
87
+
88
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
89
+
90
+
91
+ def fp8_transpose(x, transpose_output_2d=False):
92
+ # Change batched 3D input to 2D
93
+ batched = False
94
+ if len(x.shape) == 3:
95
+ batched = True
96
+ BS = x.shape[0]
97
+ x = x.reshape(-1, x.shape[-1])
98
+
99
+ # defining the input and output tensor
100
+ M, N = x.shape
101
+
102
+ y = torch.empty((N, M), dtype=x.dtype, device=x.device)
103
+
104
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
105
+
106
+ _fp8_transpose_kernel[grid](
107
+ y,
108
+ x,
109
+ M,
110
+ N,
111
+ x.stride(0),
112
+ x.stride(1),
113
+ y.stride(0),
114
+ y.stride(1),
115
+ )
116
+
117
+ # Recover 2D to 3D
118
+ if batched and not transpose_output_2d:
119
+ y = y.reshape(BS, -1, y.shape[-1])
120
+
121
+ return y
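A minimal usage sketch for fp8_transpose, assuming a CUDA device, a Triton build that can load/store float8 blocks directly, and a PyTorch version exposing torch.float8_e4m3fn; the shapes are illustrative only.

    import torch
    from llava.model.coat.activation.real_quantization._transpose import fp8_transpose

    # Batched 3D FP8 input; the function flattens it to (BS*M, N) before transposing.
    x = torch.randn(4, 128, 256, device="cuda").to(torch.float8_e4m3fn)

    y3d = fp8_transpose(x)                            # transposed 2D result folded back to 3D along the batch dim
    y2d = fp8_transpose(x, transpose_output_2d=True)  # keeps the transposed (N, BS*M) 2D layout
    print(y3d.shape, y2d.shape)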
llava/model/coat/activation/real_quantization/add_bwd.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from ._division import fp8_division
31
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
32
+
33
+ """Element-wise Add, useful for backward"""
34
+ """Input1 (Residual) uses full-precision/BF16"""
35
+ """Input2 (Backbone) uses full-precision/BF16"""
36
+ """Output1 uses full-precision/BF16"""
37
+ """Output2 uses per-tensor quantization"""
38
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
39
+
40
+
41
+ @triton.autotune(
42
+ configs=[] + get_configs_io_block(),
43
+ key=[
44
+ "N",
45
+ ],
46
+ )
47
+ @triton.heuristics(
48
+ {
49
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
50
+ }
51
+ )
52
+ @triton.jit
53
+ def _fp8_add_Ifp_Ifp_Ofp_Opt_kernel(
54
+ output1_ptr, # output
55
+ output2_scale_ptr,
56
+ input1_ptr, # input
57
+ input2_ptr, # input
58
+ M,
59
+ N,
60
+ SN,
61
+ QB: tl.constexpr,
62
+ fp8_max, # shape
63
+ input1_stride_0,
64
+ input1_stride_1, # input1 stride
65
+ input2_stride_0,
66
+ input2_stride_1, # input2 stride
67
+ output1_stride_0,
68
+ output1_stride_1, # output stride
69
+ s_output2_stride_0,
70
+ s_output2_stride_1, # scale of output stride
71
+ SCALE_MIN_THRES: tl.constexpr,
72
+ BLOCK_M: tl.constexpr,
73
+ BLOCK_N: tl.constexpr,
74
+ BLOCK_SN: tl.constexpr,
75
+ ): # CUDA block size
76
+
77
+ # Block PID
78
+ pid = tl.program_id(0)
79
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
80
+ pid_dim0 = pid // NUM_BLOCK_N
81
+ pid_dim1 = pid % NUM_BLOCK_N
82
+
83
+ # --- The first input ---
84
+ input1_block_ptr = tl.make_block_ptr(
85
+ base=input1_ptr,
86
+ shape=(M, N),
87
+ strides=(input1_stride_0, input1_stride_1),
88
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
89
+ block_shape=(BLOCK_M, BLOCK_N),
90
+ order=(1, 0),
91
+ )
92
+
93
+ input1 = tl.load(input1_block_ptr)
94
+ input1 = input1.to(tl.float32)
95
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
96
+
97
+ # --- The second input ---
98
+ input2_block_ptr = tl.make_block_ptr(
99
+ base=input2_ptr,
100
+ shape=(M, N),
101
+ strides=(input2_stride_0, input2_stride_1),
102
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
103
+ block_shape=(BLOCK_M, BLOCK_N),
104
+ order=(1, 0),
105
+ )
106
+
107
+ input2 = tl.load(input2_block_ptr)
108
+ input2 = input2.to(tl.float32)
109
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
110
+
111
+ # Actual Calculation of Add
112
+ add_output = input1 + input2
113
+
114
+ # Quantize the grad 1 - Scale calculation
115
+ abs_add_output = tl.abs(add_output)
116
+ max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES
117
+ scale_output2 = max_val / fp8_max
118
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1))
119
+
120
+ # save the fp add output
121
+ fp_add_output = add_output.to(output1_ptr.type.element_ty)
122
+ fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N))
123
+
124
+ # pointers
125
+ output1_block_ptr = tl.make_block_ptr(
126
+ base=output1_ptr,
127
+ shape=(M, N),
128
+ strides=(output1_stride_0, output1_stride_1),
129
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
130
+ block_shape=(BLOCK_M, BLOCK_N),
131
+ order=(1, 0),
132
+ )
133
+
134
+ tl.store(output1_block_ptr, fp_add_output, boundary_check=(0, 1))
135
+
136
+ # Quantize
137
+ scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty)
138
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN))
139
+
140
+ # pointers
141
+ scale_output2_ptr = tl.make_block_ptr(
142
+ base=output2_scale_ptr,
143
+ shape=(M, SN),
144
+ strides=(s_output2_stride_0, s_output2_stride_1),
145
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
146
+ block_shape=(BLOCK_M, BLOCK_SN),
147
+ order=(1, 0),
148
+ )
149
+ tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1))
150
+
151
+
152
+ def fp8_add_Ifp_Ifp_Ofp_Opt(x1, x2, QB, fp8type, stochastic=False): # suppose x1 is full precision or BF16
153
+ # Change batched 3D input to 2D
154
+ batched = False
155
+ if len(x1.shape) == 3:
156
+ assert len(x2.shape) == 3
157
+ batched = True
158
+ BS = x1.shape[0]
159
+ x1 = x1.reshape(-1, x1.shape[-1])
160
+ x2 = x2.reshape(-1, x2.shape[-1])
161
+
162
+ # defining the input and output tensor
163
+ M, N = x1.shape
164
+ SN = N // QB
165
+ assert x1.shape == x2.shape
166
+
167
+ if isinstance(fp8type, str):
168
+ fp8type = convert_str_to_fp8[fp8type]
169
+ y1 = torch.empty_like(x1, dtype=torch.bfloat16)
170
+ s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device)
171
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
172
+
173
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
174
+
175
+ _fp8_add_Ifp_Ifp_Ofp_Opt_kernel[grid](
176
+ y1,
177
+ s_y2,
178
+ x1,
179
+ x2,
180
+ M,
181
+ N,
182
+ SN,
183
+ QB,
184
+ fp8MaxValue,
185
+ x1.stride(0),
186
+ x1.stride(1),
187
+ x2.stride(0),
188
+ x2.stride(1),
189
+ y1.stride(0),
190
+ y1.stride(1),
191
+ s_y2.stride(0),
192
+ s_y2.stride(1),
193
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
194
+ )
195
+
196
+ s_y2_max = s_y2.max()
197
+ qy2, s_y2_max = fp8_division(y1, QB, fp8type, s_y2_max, stochastic=stochastic) # reuse the floating point output y1
198
+
199
+ # Recover 2D to 3D
200
+ if batched:
201
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
202
+ qy2 = qy2.reshape(BS, -1, qy2.shape[-1])
203
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
204
+
205
+ return y1, (qy2, s_y2_max, s_y2)
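A minimal usage sketch for fp8_add_Ifp_Ifp_Ofp_Opt, assuming a CUDA device with Triton available; passing the FP8 format as the string "E4M3" and using a group size QB of 16 are assumptions about what convert_str_to_fp8 and the "1 * 16" grouping in common.py accept.

    import torch
    from llava.model.coat.activation.real_quantization.add_bwd import fp8_add_Ifp_Ifp_Ofp_Opt

    grad_residual = torch.randn(2, 64, 1024, device="cuda", dtype=torch.bfloat16)
    grad_backbone = torch.randn(2, 64, 1024, device="cuda", dtype=torch.bfloat16)

    # y1: BF16 element-wise sum; (qy2, s_y2_max, s_y2): per-tensor-quantized sum plus its scales
    y1, (qy2, s_y2_max, s_y2) = fp8_add_Ifp_Ifp_Ofp_Opt(grad_residual, grad_backbone, QB=16, fp8type="E4M3")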
llava/model/coat/activation/real_quantization/add_fwd.py ADDED
@@ -0,0 +1,219 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # SPDX-License-Identifier: Apache-2.0
22
+
23
+ import torch
24
+
25
+ # 4 block
26
+ import triton
27
+ import triton.language as tl
28
+ from triton.language.extra.cuda import libdevice
29
+
30
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
31
+
32
+ """Element-wise Add, used in forward pass"""
33
+ """Input1 (Residual) uses full-precision/BF16"""
34
+ """Input2 (Backbone) uses full-precision/BF16"""
35
+ """Output1 uses full-precision/BF16"""
36
+ """Output2 uses 1 * 16 group quantization"""
37
+ """The input can be 2D or 3D, but the calculation is performed in 2D"""
38
+
39
+
40
+ @triton.autotune(
41
+ configs=[] + get_configs_io_block(),
42
+ key=[
43
+ "N",
44
+ ],
45
+ )
46
+ @triton.heuristics(
47
+ {
48
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
49
+ }
50
+ )
51
+ @triton.jit
52
+ def _fp8_add_Ifp_Ifp_Ofp_Og16_kernel(
53
+ output1_ptr, # output
54
+ output2_ptr,
55
+ output2_scale_ptr,
56
+ input1_ptr, # input
57
+ input2_ptr, # input
58
+ M,
59
+ N,
60
+ SN,
61
+ QB: tl.constexpr,
62
+ fp8_max, # shape
63
+ input1_stride_0,
64
+ input1_stride_1, # input1 stride
65
+ input2_stride_0,
66
+ input2_stride_1, # input2 stride
67
+ output1_stride_0,
68
+ output1_stride_1, # output stride
69
+ output2_stride_0,
70
+ output2_stride_1, # output stride
71
+ s_output2_stride_0,
72
+ s_output2_stride_1, # scale of output stride
73
+ SCALE_MIN_THRES: tl.constexpr,
74
+ BLOCK_M: tl.constexpr,
75
+ BLOCK_N: tl.constexpr,
76
+ BLOCK_SN: tl.constexpr,
77
+ ): # CUDA block size
78
+
79
+ # Block PID
80
+ pid = tl.program_id(0)
81
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
82
+ pid_dim0 = pid // NUM_BLOCK_N
83
+ pid_dim1 = pid % NUM_BLOCK_N
84
+
85
+ # --- The first input ---
86
+ input1_block_ptr = tl.make_block_ptr(
87
+ base=input1_ptr,
88
+ shape=(M, N),
89
+ strides=(input1_stride_0, input1_stride_1),
90
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
91
+ block_shape=(BLOCK_M, BLOCK_N),
92
+ order=(1, 0),
93
+ )
94
+
95
+ input1 = tl.load(input1_block_ptr)
96
+ input1 = input1.to(tl.float32)
97
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
98
+
99
+ # --- The second input ---
100
+ input2_block_ptr = tl.make_block_ptr(
101
+ base=input2_ptr,
102
+ shape=(M, N),
103
+ strides=(input2_stride_0, input2_stride_1),
104
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
105
+ block_shape=(BLOCK_M, BLOCK_N),
106
+ order=(1, 0),
107
+ )
108
+
109
+ input2 = tl.load(input2_block_ptr)
110
+ input2 = input2.to(tl.float32)
111
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
112
+
113
+ # Actual Calculation of Add
114
+ add_output = input1 + input2
115
+
116
+ # Quantize the grad 1 - Scale calculation
117
+ abs_add_output = tl.abs(add_output)
118
+ max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES
119
+ scale_output2 = max_val / fp8_max
120
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1))
121
+
122
+ # save the fp add output
123
+ fp_add_output = add_output.to(output1_ptr.type.element_ty)
124
+ fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N))
125
+
126
+ # pointers
127
+ output1_block_ptr = tl.make_block_ptr(
128
+ base=output1_ptr,
129
+ shape=(M, N),
130
+ strides=(output1_stride_0, output1_stride_1),
131
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
132
+ block_shape=(BLOCK_M, BLOCK_N),
133
+ order=(1, 0),
134
+ )
135
+
136
+ tl.store(output1_block_ptr, fp_add_output)
137
+
138
+ # Quantize
139
+ add_output = tl.fdiv(add_output, scale_output2)
140
+ scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty)
141
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN))
142
+ add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N))
143
+
144
+ add_output = add_output.to(output2_ptr.type.element_ty)
145
+ add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N))
146
+
147
+ # pointers
148
+ output2_block_ptr = tl.make_block_ptr(
149
+ base=output2_ptr,
150
+ shape=(M, N),
151
+ strides=(output2_stride_0, output2_stride_1),
152
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
153
+ block_shape=(BLOCK_M, BLOCK_N),
154
+ order=(1, 0),
155
+ )
156
+ scale_output2_ptr = tl.make_block_ptr(
157
+ base=output2_scale_ptr,
158
+ shape=(M, SN),
159
+ strides=(s_output2_stride_0, s_output2_stride_1),
160
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
161
+ block_shape=(BLOCK_M, BLOCK_SN),
162
+ order=(1, 0),
163
+ )
164
+ tl.store(output2_block_ptr, add_output, boundary_check=(0, 1))
165
+ tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1))
166
+
167
+
168
+ def fp8_add_Ifp_Ifp_Ofp_Og16(x1, x2, fp8type, QB): # suppose x1 is full precision or BF16
169
+ # Change batched 3D input to 2D
170
+ batched = False
171
+ if len(x1.shape) == 3:
172
+ batched = True
173
+ BS = x1.shape[0]
174
+ x1 = x1.reshape(-1, x1.shape[-1])
175
+ x2 = x2.reshape(-1, x2.shape[-1])
176
+
177
+ # defining the input and output tensor
178
+ M, N = x1.shape
179
+ SN = int(N / QB) # assume the shape of quantization block size is always 1 * G
180
+ assert x1.shape == x2.shape
181
+
182
+ y1 = torch.empty_like(x1, dtype=torch.bfloat16)
183
+ y2 = torch.empty_like(x2, dtype=fp8type)
184
+ s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device)
185
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
186
+
187
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
188
+
189
+ _fp8_add_Ifp_Ifp_Ofp_Og16_kernel[grid](
190
+ y1,
191
+ y2,
192
+ s_y2,
193
+ x1,
194
+ x2,
195
+ M,
196
+ N,
197
+ SN,
198
+ QB,
199
+ fp8MaxValue,
200
+ x1.stride(0),
201
+ x1.stride(1),
202
+ x2.stride(0),
203
+ x2.stride(1),
204
+ y1.stride(0),
205
+ y1.stride(1),
206
+ y2.stride(0),
207
+ y2.stride(1),
208
+ s_y2.stride(0),
209
+ s_y2.stride(1),
210
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
211
+ )
212
+
213
+ # Recover 2D to 3D
214
+ if batched:
215
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
216
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
217
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
218
+
219
+ return y1, (y2, s_y2)
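A minimal usage sketch for fp8_add_Ifp_Ifp_Ofp_Og16, again assuming a CUDA device with Triton available. Unlike the backward variant, fp8type is consumed directly as a torch dtype (it is passed to torch.empty_like and used as the FP8_MAX_VALUE key), so a torch.float8_* dtype is supplied; treating torch.float8_e4m3fn as a valid key is an assumption about common.py.

    import torch
    from llava.model.coat.activation.real_quantization.add_fwd import fp8_add_Ifp_Ifp_Ofp_Og16

    residual = torch.randn(2, 64, 1024, device="cuda", dtype=torch.bfloat16)
    backbone = torch.randn(2, 64, 1024, device="cuda", dtype=torch.bfloat16)

    # y1: BF16 sum kept for the residual stream; (y2, s_y2): FP8 sum with 1x16 per-group scales
    y1, (y2, s_y2) = fp8_add_Ifp_Ifp_Ofp_Og16(residual, backbone, torch.float8_e4m3fn, QB=16)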