fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path (#1298)
* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path
* fix: normalize config
src/axolotl/utils/config/__init__.py

```diff
@@ -119,6 +119,10 @@ def normalize_config(cfg):
     model_config = load_model_config(cfg)
     cfg.model_config_type = model_config.model_type
 
+    cfg.tokenizer_config = (
+        cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
+    )
+
     # figure out if the model is llama
     cfg.is_llama_derived_model = (
        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
```
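Resolving the fallback once in `normalize_config` gives every downstream consumer the same tokenizer path instead of each re-deriving it locally. A minimal sketch of the resolution order, assuming `DictDefault`-style configs where missing keys are falsy (the helper name is hypothetical, for illustration only):

```python
# Hypothetical helper mirroring the fallback chain added to normalize_config:
# `or` returns the first truthy operand, so an explicit tokenizer_config wins,
# then base_model_config, then base_model.
def resolve_tokenizer_path(tokenizer_config, base_model_config, base_model):
    return tokenizer_config or base_model_config or base_model

assert resolve_tokenizer_path(None, None, "org/model") == "org/model"
assert resolve_tokenizer_path("org/tok", "org/cfg", "org/model") == "org/tok"
```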
src/axolotl/utils/data.py

```diff
@@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
     split="train",
 ) -> Tuple[DatasetDict, List[Prompter]]:
     cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
-    tokenizer_name = tokenizer.__class__.__name__
+    tokenizer_name = cfg.tokenizer_config
     ds_hash = str(
         md5(
             (
```
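Hashing the tokenizer path rather than the class name keeps prepared-dataset caches from colliding when two checkpoints share a tokenizer class (e.g. two different `LlamaTokenizer` repos). A hedged sketch of the effect; the real `ds_hash` mixes in more of the dataset config than shown here:

```python
from hashlib import md5

def prepared_ds_key(dataset_repr: str, tokenizer_name: str) -> str:
    # The cache key now varies with the tokenizer *path*, not just its class.
    return md5((dataset_repr + tokenizer_name).encode("utf-8")).hexdigest()

# Before: both runs hashed the shared class name (e.g. "LlamaTokenizer")
# and hit the same cache entry. After: distinct paths yield distinct keys.
assert prepared_ds_key("alpaca", "org/model-a") != prepared_ds_key("alpaca", "org/model-b")
```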
src/axolotl/utils/models.py

```diff
@@ -134,9 +134,8 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
 
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
     tokenizer = tokenizer_cls.from_pretrained(
-        tokenizer_config,
+        cfg.tokenizer_config,
         trust_remote_code=cfg.trust_remote_code or False,
         use_fast=use_fast,
         **tokenizer_kwargs,
```
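With the fallback hoisted into `normalize_config`, `load_tokenizer` no longer computes its own local `tokenizer_config` and simply trusts the normalized value. A hedged usage sketch (the model name is illustrative, and loading requires the checkpoint to be reachable):

```python
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer

cfg = DictDefault({"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6"})
normalize_config(cfg)  # populates cfg.tokenizer_config from base_model
tokenizer = load_tokenizer(cfg)  # calls from_pretrained(cfg.tokenizer_config, ...)
```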
tests/core/test_trainer_builder.py

```diff
@@ -1,16 +1,18 @@
 """
 unit tests for axolotl.core.trainer_builder
 """
+
 import pytest
 
 from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 
 
 @pytest.fixture(name="cfg")
 def fixture_cfg():
-    return DictDefault(
+    cfg = DictDefault(
         {
             "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
             "model_type": "AutoModelForCausalLM",
@@ -34,6 +36,10 @@ def fixture_cfg():
         }
     )
 
+    normalize_config(cfg)
+
+    return cfg
+
 
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer(cfg):
```
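Because the fixture now runs `normalize_config` before returning, tests exercise the same tokenizer-path resolution as production code. A hypothetical extra test (not part of this diff), assuming the fixture sets neither `tokenizer_config` nor `base_model_config`:

```python
def test_tokenizer_config_resolved(cfg):
    # Hypothetical check: normalize_config should have filled in the
    # tokenizer path from base_model when no explicit value was given.
    assert cfg.tokenizer_config == "TinyLlama/TinyLlama-1.1B-Chat-v0.6"
```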