Fix for check with cfg and merge_lora (#600)
Browse files
.github/workflows/tests.yml
CHANGED
|
@@ -61,7 +61,7 @@ jobs:
|
|
| 61 |
uses: actions/setup-python@v4
|
| 62 |
with:
|
| 63 |
python-version: "3.10"
|
| 64 |
-
cache: 'pip' # caching pip dependencies
|
| 65 |
|
| 66 |
- name: Install dependencies
|
| 67 |
run: |
|
|
|
|
| 61 |
uses: actions/setup-python@v4
|
| 62 |
with:
|
| 63 |
python-version: "3.10"
|
| 64 |
+
# cache: 'pip' # caching pip dependencies
|
| 65 |
|
| 66 |
- name: Install dependencies
|
| 67 |
run: |
|
src/axolotl/cli/__init__.py
CHANGED
|
@@ -70,7 +70,7 @@ def do_merge_lora(
|
|
| 70 |
model.to(dtype=torch.float16)
|
| 71 |
|
| 72 |
if cfg.local_rank == 0:
|
| 73 |
-
LOG.info("saving merged model")
|
| 74 |
model.save_pretrained(
|
| 75 |
str(Path(cfg.output_dir) / "merged"),
|
| 76 |
safe_serialization=safe_serialization,
|
|
|
|
| 70 |
model.to(dtype=torch.float16)
|
| 71 |
|
| 72 |
if cfg.local_rank == 0:
|
| 73 |
+
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
| 74 |
model.save_pretrained(
|
| 75 |
str(Path(cfg.output_dir) / "merged"),
|
| 76 |
safe_serialization=safe_serialization,
|
src/axolotl/cli/merge_lora.py
CHANGED
|
@@ -13,12 +13,12 @@ from axolotl.common.cli import TrainerCliArgs
|
|
| 13 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
| 14 |
# pylint: disable=duplicate-code
|
| 15 |
print_axolotl_text_art()
|
| 16 |
-
parsed_cfg = load_cfg(config, **kwargs)
|
| 17 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
| 18 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
| 19 |
return_remaining_strings=True
|
| 20 |
)
|
| 21 |
parsed_cli_args.merge_lora = True
|
|
|
|
| 22 |
|
| 23 |
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
| 24 |
|
|
|
|
| 13 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
| 14 |
# pylint: disable=duplicate-code
|
| 15 |
print_axolotl_text_art()
|
|
|
|
| 16 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
| 17 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
| 18 |
return_remaining_strings=True
|
| 19 |
)
|
| 20 |
parsed_cli_args.merge_lora = True
|
| 21 |
+
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
|
| 22 |
|
| 23 |
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
| 24 |
|