fix(config): Set eos/bos to tokenizer if different (#801)
Browse files
* fix(config): Set eos/bos to tokenizer if different
* chore: fix lint
- src/axolotl/utils/models.py +14 -0
src/axolotl/utils/models.py
CHANGED
|
@@ -386,6 +386,20 @@ def load_model(
|
|
| 386 |
)
|
| 387 |
model.config.max_position_embeddings = cfg.sequence_len
|
| 388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
if model.device.type == "cuda":
|
| 390 |
log_gpu_memory_usage(LOG, "after model load", model.device)
|
| 391 |
|
|
|
|
| 386 |
)
|
| 387 |
model.config.max_position_embeddings = cfg.sequence_len
|
| 388 |
|
| 389 |
+
if (
|
| 390 |
+
hasattr(model.config, "bos_token_id")
|
| 391 |
+
and model.config.bos_token_id
|
| 392 |
+
and model.config.bos_token_id != tokenizer.bos_token_id
|
| 393 |
+
):
|
| 394 |
+
model.config.bos_token_id = tokenizer.bos_token_id
|
| 395 |
+
|
| 396 |
+
if (
|
| 397 |
+
hasattr(model.config, "eos_token_id")
|
| 398 |
+
and model.config.eos_token_id
|
| 399 |
+
and model.config.eos_token_id != tokenizer.eos_token_id
|
| 400 |
+
):
|
| 401 |
+
model.config.eos_token_id = tokenizer.eos_token_id
|
| 402 |
+
|
| 403 |
if model.device.type == "cuda":
|
| 404 |
log_gpu_memory_usage(LOG, "after model load", model.device)
|
| 405 |
|