Jan Philipp Harries
Jan Philipp Harries
committed on
set fsdp state dict (#584)
Browse files
Co-authored-by: Jan Philipp Harries <[email protected]>
- src/axolotl/train.py +4 -0
src/axolotl/train.py
CHANGED
|
@@ -117,6 +117,10 @@ def train(
|
|
| 117 |
|
| 118 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
if cfg.relora_steps:
|
| 121 |
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
| 122 |
model = model.merge_and_unload()
|
|
|
|
| 117 |
|
| 118 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
| 119 |
|
| 120 |
+
if trainer.is_fsdp_enabled:
|
| 121 |
+
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
| 122 |
+
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
| 123 |
+
|
| 124 |
if cfg.relora_steps:
|
| 125 |
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
| 126 |
model = model.merge_and_unload()
|