log GPU memory usage
- requirements.txt +1 -0
- scripts/finetune.py +3 -0
- src/axolotl/utils/bench.py +23 -0
- src/axolotl/utils/callbacks.py +27 -0
- src/axolotl/utils/models.py +7 -0
- src/axolotl/utils/trainer.py +2 -0
requirements.txt
CHANGED

@@ -19,3 +19,4 @@ evaluate==0.4.0
 rouge-score==0.1.2
 scipy
 scikit-learn==1.2.2
+nvidia-ml-py3
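The new dependency, nvidia-ml-py3, provides the pynvml bindings that the bench module added below relies on. As a rough sanity check (not part of the commit), a sketch like the following confirms the bindings can reach the NVIDIA driver:

import pynvml

# Initialize NVML, count visible devices, then shut the handle down again.
pynvml.nvmlInit()
print("visible GPUs:", pynvml.nvmlDeviceGetCount())
pynvml.nvmlShutdown()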
scripts/finetune.py
CHANGED

@@ -18,6 +18,7 @@ from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
+from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer

@@ -250,6 +251,8 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return
 
+    log_gpu_memory_usage(LOG, "baseline", cfg.device)
+
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
     model, peft_config = load_model(
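The baseline reading is taken right before the model is loaded so the later "after model load" line (see models.py below) can be read as a delta. A hypothetical illustration of that arithmetic, outside of train() and not part of the commit:

from axolotl.utils.bench import gpu_memory_usage

baseline = gpu_memory_usage(0)      # GB in use on GPU 0 before loading
# ... load_model(...) would run here ...
after_load = gpu_memory_usage(0)    # GB in use once the weights are resident
print(f"model footprint is roughly {after_load - baseline:.2f} GB")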
src/axolotl/utils/bench.py
ADDED

@@ -0,0 +1,23 @@
+"""Benchmarking and measurement utilities"""
+
+import pynvml
+import torch
+
+
+def gpu_memory_usage(device):
+    if isinstance(device, torch.device):
+        device = device.index
+    if isinstance(device, str) and device.startswith("cuda:"):
+        device = int(device[5:])
+
+    # NB torch.cuda.memory_usage returns zero so we use lower level api
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+    return info.used / 1024.0**3
+
+
+def log_gpu_memory_usage(log, msg, device):
+    log.info(
+        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
+    )
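A minimal way to exercise the new helpers from a Python shell, assuming axolotl is installed and at least one CUDA device is visible:

import logging

import torch
from axolotl.utils.bench import gpu_memory_usage, log_gpu_memory_usage

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("bench-demo")

# The helpers accept a bare index, a "cuda:N" string, or a torch.device.
print(f"{gpu_memory_usage('cuda:0'):.3f} GB in use")
log_gpu_memory_usage(log, "demo", torch.device("cuda", 0))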
src/axolotl/utils/callbacks.py
CHANGED

@@ -1,5 +1,6 @@
 """Callbacks for Trainer class"""
 
+import logging
 import os
 
 from optimum.bettertransformer import BetterTransformer

@@ -11,6 +12,10 @@ from transformers import (
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
+from axolotl.utils.bench import log_gpu_memory_usage
+
+LOG = logging.getLogger("axolotl.callbacks")
+
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
     """Callback to save the PEFT adapter"""

@@ -67,3 +72,25 @@ class SaveBetterTransformerModelCallback(
             # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
             control.should_save = False
         return control
+
+
+class PrintGPUStatsCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods disable=unused-argument
+    """Callback to print GPU utilization"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.logged = False
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        if not self.logged:
+            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
+            self.logged = True
+        return control
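PrintGPUStatsCallback logs exactly once, at the end of the first training step, so the log captures usage after gradients and optimizer state have been allocated without repeating the line every step. A rough standalone check of that one-shot behaviour (cfg here is a stand-in namespace, not axolotl's config object):

import logging
from types import SimpleNamespace

from axolotl.utils.callbacks import PrintGPUStatsCallback

logging.basicConfig(level=logging.INFO)
cb = PrintGPUStatsCallback(SimpleNamespace(device="cuda:0"))
cb.on_step_end(args=None, state=None, control=None)  # logs "GPU memory usage while training: ..."
cb.on_step_end(args=None, state=None, control=None)  # silent; self.logged is already True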
src/axolotl/utils/models.py
CHANGED

@@ -22,6 +22,7 @@ from transformers import (  # noqa: F401
 )
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
+from axolotl.utils.bench import log_gpu_memory_usage
 
 LOG = logging.getLogger("axolotl")
 

@@ -324,6 +325,9 @@ def load_model(
         )
         model.config.max_position_embeddings = cfg.sequence_len
 
+    if model.device.type == "cuda":
+        log_gpu_memory_usage(LOG, "after model load", model.device)
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)

@@ -360,6 +364,9 @@ def load_model(
                 module.scales = module.scales.half()
                 module.bias = module.bias.half()
 
+    if model.device.type == "cuda":
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
+
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
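Both new logging points in load_model are guarded on model.device.type, so CPU-only or meta-device loads never trigger an NVML query. The same guard sketched on its own, with a hypothetical helper that is not part of the commit:

import logging

from axolotl.utils.bench import log_gpu_memory_usage

LOG = logging.getLogger("axolotl")

def maybe_log_gpu_usage(model, msg):  # hypothetical helper, for illustration only
    # transformers models report this same device via model.device
    device = next(model.parameters()).device
    if device.type == "cuda":
        log_gpu_memory_usage(LOG, msg, device)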
src/axolotl/utils/trainer.py
CHANGED

@@ -18,6 +18,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
+    PrintGPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )

@@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
     callbacks = []
+    callbacks.append(PrintGPUStatsCallback(cfg))
    # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(