# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of https://github.com/pytorch/torchtune.
import warnings
import psutil
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks
class OffloadActivations(saved_tensors_hooks):
"""
Context manager under which activation tensors created in the forward pass will be offloaded.
Enables the memory-efficiency technique of activation offloading, where activations larger than `min_offload_size`
bytes are offloaded to the CPU during the forward pass and brought back during the backward pass, instead of being
kept in GPU VRAM throughout the program.
This manager can optionally use one additional CUDA stream to handle the communication between GPU and CPU, which
is intended to overlap with the default computation stream to improve runtime. The synchronization is designed with
a few heuristics to optimize the tradeoff between runtime and memory usage.
Args:
use_pin_memory (`bool`, *optional*, defaults to `True`):
Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
be moved back onto the GPU more quickly, but is a limited resource.
use_streams (`bool`, *optional*, defaults to `True`):
Whether to use streams for performance optimization where the communications get overlapped with the
computation. Requires torch 2.5.0 or later.
min_offload_size (`int`, *optional*, defaults to `1024`):
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
do not want to waste bandwidth and resources moving it to CPU and back.
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
runtime.
Raises:
ValueError: if `max_fwd_stash_size` is not at least `1`.
Example:
>>> with OffloadActivations():
>>> outputs = model(inputs, labels=labels)
>>> loss = outputs.loss
>>> loss.backward()
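If your torch version is older than 2.5.0, streams cannot be used. A minimal sketch with streaming disabled
(`model`, `inputs`, and `labels` are user-defined placeholders):
>>> with OffloadActivations(use_streams=False):
...     outputs = model(inputs, labels=labels)  # activations >= min_offload_size bytes are copied to CPU
...     outputs.loss.backward()  # offloaded activations are brought back during backward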
"""
def __init__(
self,
use_pin_memory: bool = True,
use_streams: bool = True,
min_offload_size: int = 1024,
max_fwd_stash_size: int = 5,
) -> None:
self.use_streams = use_streams
self.min_tensor_size_bytes = min_offload_size # we don't want to bother with small tensors
self.tracker = {} # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
self.tensor_id = 0
self.is_first_forward_call = True
self.is_first_backward_call = True
self.is_first_forward_pass = True
# Managing cpu memory
self.use_pin_memory = use_pin_memory
self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory
self.accelerator_type = (
torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
)
# NOTE: xpu doesn't have `default_stream` API, use `current_stream` instead
self.s0 = (
torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream()
) # comp stream
# For streaming
if self.use_streams:
self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream() # comms stream
self.fwd_stash = {} # tensor_id => (activation, ev1)
if max_fwd_stash_size < 1:
raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}")
self.max_fwd_stash_size = max_fwd_stash_size
self.bwd_tensor_stash = {} # tensor_id => activation
self.bwd_ev_stash = {} # tensor_id => ev0
self.curr_graph_id = None
self.curr_autograd_node = None
# -------- platform util functions -------- #
def verify_sufficient_virtual_memory():
curr_pct = get_cpu_ram_pct()
if curr_pct > self.virtual_memory_safe_pct:
warnings.warn(f"{curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used")
def get_cpu_ram_pct() -> float:
# get the percentage of memory used by the system
return psutil.virtual_memory().percent
def get_tensor_id() -> int:
# create a unique id for each tensor we are managing
self.tensor_id += 1
return self.tensor_id
def get_num_bytes_tensor(x: torch.Tensor) -> int:
# get the number of bytes in a tensor, for memory management purposes
return x.element_size() * x.nelement() # x.element_size() * x._base_storage().nbytes()
# -------- core pack / unpack work -------- #
def pack_tensor(activation: torch.Tensor) -> int:
# activations are passed in during forward pass - from here we take over and return a unique id
if self.is_first_forward_call:
if len(self.tracker) != 0:
raise ValueError("Backward pass should have cleared tracker of all tensors")
# set training phase trackers
self.is_first_forward_call = False
self.is_first_backward_call = True
# query for basic tensor info
num_bytes = get_num_bytes_tensor(activation)
tensor_id = get_tensor_id()
# only offload hefty bois if they're activations on CUDA or XPU (our heuristic
# for "activation" is to check that they're not params or buffers)!
if (
activation.device.type in ["cuda", "xpu"]
and num_bytes >= self.min_tensor_size_bytes
and (
not isinstance(activation, torch.nn.Parameter)
and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer))
)
):
if self.use_streams:
# First, sync back and dereference previously offloaded tensors
# as the offloading should be done sufficiently long ago.
for id in list(self.fwd_stash.keys()):
if id <= tensor_id - self.max_fwd_stash_size:
_, ev = self.fwd_stash[id]
self.s0.wait_event(ev)
del self.fwd_stash[id]
else:
break
# Sync in, offload, and add an event to sync back later
self.s1.wait_stream(self.s0)
stream = self.s1 if self.use_streams else self.s0
with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
cpu_tensor.copy_(activation, non_blocking=True)
self.tracker[tensor_id] = (
cpu_tensor,
True, # True = (in future) modified
)
if self.use_streams:
event = self.s1.record_event()
# Stash to keep activation alive til s1 is done
self.fwd_stash[tensor_id] = (activation, event)
else:
self.tracker[tensor_id] = (
activation,
False,
) # False = not modified, tensor is as is
return tensor_id
def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
# backward pass - we are called with the tensor_id, which
# we will use to retrieve the saved/offloaded tensor
if self.is_first_backward_call:
if self.is_first_forward_pass:
self.is_first_forward_pass = False
if self.use_pin_memory:
verify_sufficient_virtual_memory()
self.is_first_backward_call = False
self.is_first_forward_call = True
if unpack_tensor_id not in self.tracker:
raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
if modified:
accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
maybe_accelerator_tensor = accelerator_tensor
# clear tensor from tracking
del self.tracker[unpack_tensor_id]
return maybe_accelerator_tensor
def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
# backward pass - we are called with the tensor_id, which
# we will use to retrieve the saved/offloaded tensor
if self.is_first_backward_call:
self.curr_graph_id = torch._C._current_graph_task_id()
def wait_and_del_remaining_references() -> None:
for id in list(self.bwd_tensor_stash.keys()):
event = self.bwd_ev_stash[id]
self.s1.wait_event(event)
del self.bwd_tensor_stash[id]
# Register a callback to the end of autograd to clean everything up
torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references)
if self.is_first_forward_pass:
self.is_first_forward_pass = False
if self.use_pin_memory:
verify_sufficient_virtual_memory()
self.is_first_backward_call = False
self.is_first_forward_call = True
if unpack_tensor_id not in self.tracker:
raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
if modified:
# Get data on the current autograd node
graph_id = torch._C._current_graph_task_id()
node = torch._C._current_autograd_node()
prev_node_ids = []
# If we're on a new node, mark prev node's tensors to be freed later
if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
self.curr_autograd_node = node
prev_node_ids = list(self.bwd_tensor_stash.keys())
brought_back_from_cpu = True
if unpack_tensor_id in self.fwd_stash:
maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
brought_back_from_cpu = False
else:
# Kick off the process to bring tensors back
with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1):
accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
maybe_accelerator_tensor = accelerator_tensor
# Tell comp stream to wait for the info to be loaded before executing
self.s0.wait_stream(self.s1)
# Stash the tensor to keep memory alive until compute stream is complete
self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor
# Note: [Track views of the unpacked]
# Why do we get the use count of the unpacked tensor here? We want an
# initial count to compare to later, during the post-hook of the
# backward node, when we need to decide whether we're allowed to free
# the tensor yet. In what obscure cases must we delay freeing the
# tensor (and thus call record_stream)?
# 1. Any of the outputs of the backward node is a view of the unpacked
# tensor.
# 2. In the case that this unpacked tensor will be used in a
# checkpointed region, if one of the recomputed saved tensors ends
# up as a view of the unpacked tensor.
# 3. The user abuses the system somehow and manually relies on the
# unpacked tensor to exist after the backward node has executed.
storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata)
def hook(outputs, inputs):
# create events for the current node inputs/outputs if they were streamed in
if brought_back_from_cpu:
# See Note: [Track views of the unpacked]
# IF any of the outputs is a view of the tensor, OR if a view of
# the tensor has been saved as a part of checkpoint's recompute
# process, OR the user has improperly kept a reference to the
# unpacked tensor, THEN the tensor might be used later and we
# cannot presume to delete it after only the current node is
# done! So we use our frenemy, record_stream, to ensure the
# Tensor stays unmessed with until it's done getting used in the
# compute stream (s0 here). Note that the con here is we introduce
# non-deterministic (thus higher) memory usage, but this case
# should not happen often.
unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
if torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) > storage_refcount:
unpacked_tensor.record_stream(self.s0)
del self.bwd_tensor_stash[unpack_tensor_id]
else:
event = self.s0.record_event()
self.bwd_ev_stash[unpack_tensor_id] = event
# if there are still things in the fwd_stash, get rid of them as we're in bwd now
for id in list(self.fwd_stash.keys()):
_, ev = self.fwd_stash[id]
self.s0.wait_event(ev)
del self.fwd_stash[id]
# wait on prev node's events and del those
for id in prev_node_ids:
event = self.bwd_ev_stash[id]
self.s1.wait_event(event)
del self.bwd_tensor_stash[id]
return outputs
node.register_hook(hook)
# clear tensor from tracking
del self.tracker[unpack_tensor_id]
return maybe_accelerator_tensor
unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
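# Register pack/unpack as the pack_hook/unpack_hook of torch.autograd.graph.saved_tensors_hooks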
super().__init__(pack_tensor, unpack_tensor)
class NoOpManager(saved_tensors_hooks):
"""
A `saved_tensors_hooks` manager used to disable any other `saved_tensors_hooks` manager applied before. This
relies on the behavior that only the most recently registered `saved_tensors_hooks` will run.
One example usage is to opt a local region of code out of activation offloading, which is usually applied globally
to best track state.
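Example (illustrative sketch; `model` and `inputs` are user-defined placeholders):
>>> with OffloadActivations():
...     with NoOpManager():
...         # activations saved in this region stay on the accelerator
...         hidden = model(inputs)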
"""
def __init__(self) -> None:
def noop(tensor):
return tensor
super().__init__(noop, noop)
def get_act_offloading_ctx_manager(
model: nn.Module,
use_pin_memory: bool = True,
use_streams: bool = True,
min_offload_size: int = 1024,
max_fwd_stash_size: int = 5,
warn_if_no_head: bool = True,
) -> OffloadActivations:
"""
Returns the activation offloading context manager for the model. Activations of every module except the last
output Linear (the output head) in each step will be offloaded; the output head is wrapped with a `NoOpManager`
because offloading its large activation only to bring it back shortly afterwards is not worth the transfer cost.
Args:
model (`nn.Module`):
Model to wrap with the activation offloading context manager.
use_pin_memory (`bool`, *optional*, defaults to `True`):
Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
be moved back onto the GPU more quickly, but is a limited resource.
use_streams (`bool`, *optional*, defaults to `True`):
Whether to use streams for performance optimization where the communications get overlapped with the
computation. Requires torch 2.5.0 or later.
min_offload_size (`int`, *optional*, defaults to `1024`):
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
do not want to waste bandwidth and resources moving it to CPU and back.
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
runtime.
warn_if_no_head (`bool`, *optional*, defaults to `True`):
Whether to warn if no output head is detected. If set to `False`, the warning is suppressed.
Returns:
`OffloadActivations`:
Activation offloading context manager for the model.
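Example:
A minimal usage sketch (assumes `model`, `inputs`, and `labels` are defined by the caller):
>>> ctx = get_act_offloading_ctx_manager(model)
>>> with ctx:
...     # all activations except the output head's are offloaded to CPU during the forward pass
...     loss = model(inputs, labels=labels).loss
...     loss.backward()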
"""
activations_handling_ctx = OffloadActivations(
use_pin_memory=use_pin_memory,
use_streams=use_streams,
min_offload_size=min_offload_size,
max_fwd_stash_size=max_fwd_stash_size,
)
# Below is our hack to disable offloading the last output Linear in every
# step, as the cost for offloading the activation and then soon after bringing
# it back is expensive.
output_head_detected = False
noop_ctx = NoOpManager()
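# The forward pre-hook enters the NoOpManager before the head runs and the forward hook exits it afterwards,
# so tensors saved inside the head are packed by the no-op hooks instead of being offloaded.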
# Try to get the actual model if it's wrapped
unwrapped_model = model
if hasattr(unwrapped_model, "module"):
unwrapped_model = unwrapped_model.module
# check for PEFT models
if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model, "peft_config"):
unwrapped_model = unwrapped_model.base_model
# Check for different types of output heads
if hasattr(unwrapped_model, "output"):
if isinstance(unwrapped_model.output, nn.Module):
unwrapped_model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
elif hasattr(unwrapped_model.output, "linear") and isinstance(unwrapped_model.output.linear, nn.Module):
unwrapped_model.output.linear.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.output.linear.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for HuggingFace model output heads
elif hasattr(unwrapped_model, "lm_head"):
unwrapped_model.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for decoder-based models
elif hasattr(unwrapped_model, "decoder"):
decoder = unwrapped_model.decoder
if hasattr(decoder, "output"):
decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
decoder.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Some models have lm_head in the decoder
elif hasattr(decoder, "lm_head"):
decoder.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
decoder.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for transformer models with final layer norm
elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(unwrapped_model, "ln_f"):
final_norm = getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
final_norm.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for models with head module
elif hasattr(unwrapped_model, "head") and isinstance(unwrapped_model.head, nn.Module):
unwrapped_model.head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
if not output_head_detected and warn_if_no_head:
warnings.warn(
"During activation offloading, no output head was detected. If your model has an output head, it will be "
"offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
"passing `warn_if_no_head=False`."
)
# Disable offloading for any Liger modules
for name, module in unwrapped_model.named_modules():
if "liger" in name.lower():
module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
module.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
return activations_handling_ctx