feat: exclude mamba blocks from int8 quantization for jamba (#1578)
Browse files
src/axolotl/utils/models.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
"""Module for models and model loading"""
|
|
|
|
| 2 |
# pylint: disable=too-many-lines
|
| 3 |
|
| 4 |
import logging
|
|
@@ -504,6 +505,9 @@ def load_model(
|
|
| 504 |
bnb_config = {
|
| 505 |
"load_in_8bit": True,
|
| 506 |
}
|
|
|
|
|
|
|
|
|
|
| 507 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 508 |
**bnb_config,
|
| 509 |
)
|
|
|
|
| 1 |
"""Module for models and model loading"""
|
| 2 |
+
|
| 3 |
# pylint: disable=too-many-lines
|
| 4 |
|
| 5 |
import logging
|
|
|
|
| 505 |
bnb_config = {
|
| 506 |
"load_in_8bit": True,
|
| 507 |
}
|
| 508 |
+
# Exclude mamba blocks from int8 quantization for jamba
|
| 509 |
+
if cfg.model_config_type == "jamba":
|
| 510 |
+
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
| 511 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 512 |
**bnb_config,
|
| 513 |
)
|