# Copyright 2020-2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of https://github.com/pytorch/torchtune. import warnings import psutil import torch from torch import nn from torch.autograd.graph import saved_tensors_hooks class OffloadActivations(saved_tensors_hooks): """ Context manager under which activation tensors created in the forward pass will be offloaded. Enable the memory efficiency technique of activation offloading, where activations bigger than `min_offload_size` bytes will be offloaded to CPU in the forward and brought back in the backward. This is in contrast to maintaining the activation on GPU VRAM throughout the program. This manager contains the option of using one additional CUDA stream to handle the communication between CUDA and CPU, which is intended to overlap with the default computation stream to improve runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between runtime vs memory usage. Args: use_pin_memory (`bool`, *optional*, defaults to `True`): Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly but is a limited resource. use_streams (`bool`, *optional*, defaults to `True`): Whether to use streams for performance optimization where the communications get overlapped with the computation. Requires a torch build after torch-2.5.0. min_offload_size (`int`, *optional*, defaults to `1024`): Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we do not want to waste bandwidth and resources moving it to CPU and back. max_fwd_stash_size (`int`, *optional*, defaults to `5`): Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing runtime. Raises: ValueError: if `max_fwd_stash_size` is not at least `1`. Example: >>> with OffloadActivations(): >>> outputs = model(inputs, labels=labels) >>> loss = outputs.loss >>> loss.backward() """ def __init__( self, use_pin_memory: bool = True, use_streams: bool = True, min_offload_size: int = 1024, max_fwd_stash_size: int = 5, ) -> None: self.use_streams = use_streams self.min_tensor_size_bytes = min_offload_size # we don't want to bother with small tensors self.tracker = {} # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where self.tensor_id = 0 self.is_first_forward_call = True self.is_first_backward_call = True self.is_first_forward_pass = True # Managing cpu memory self.use_pin_memory = use_pin_memory self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory self.accelerator_type = ( torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" ) # NOTE: xpu doesn't have `default_stream` API, use `current_stream` instead self.s0 = ( torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream() ) # comp stream # For streaming if self.use_streams: self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream() # comms stream self.fwd_stash = {} # tensor_id => (activation, ev1) if max_fwd_stash_size < 1: raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}") self.max_fwd_stash_size = max_fwd_stash_size self.bwd_tensor_stash = {} # tensor_id => activation self.bwd_ev_stash = {} # tensor_id => ev0 self.curr_graph_id = None self.curr_autograd_node = None # -------- platform util functions -------- # def verify_sufficient_virtual_memory(): curr_pct = get_cpu_ram_pct() if curr_pct > self.virtual_memory_safe_pct: warnings.warn(f"{curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used") def get_cpu_ram_pct() -> float: # get the percentage of memory used by the system return psutil.virtual_memory().percent def get_tensor_id() -> int: # create a unique id for each tensor we are managing self.tensor_id += 1 return self.tensor_id def get_num_bytes_tensor(x: torch.Tensor) -> int: # get the number of bytes in a tensor, for memory management purposes return x.element_size() * x.nelement() # x.element_size() * x._base_storage().nbytes() # -------- core pack / unpack work -------- # def pack_tensor(activation: torch.Tensor) -> int: # activations are passed in during forward pass - from here we take over and return a unique id if self.is_first_forward_call: if len(self.tracker) != 0: raise ValueError("Backward pass should have cleared tracker of all tensors") # set training phase trackers self.is_first_forward_call = False self.is_first_backward_call = True # query for basic tensor info num_bytes = get_num_bytes_tensor(activation) tensor_id = get_tensor_id() # only offload hefty bois if they're activations on CUDA (our heuristic # for that is to check if they're not params or buffers)! if ( activation.device.type in ["cuda", "xpu"] and num_bytes >= self.min_tensor_size_bytes and ( not isinstance(activation, torch.nn.Parameter) and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer)) ) ): if self.use_streams: # First, sync back and dereference previously offloaded tensors # as the offloading should be done sufficiently long ago. for id in list(self.fwd_stash.keys()): if id <= tensor_id - self.max_fwd_stash_size: _, ev = self.fwd_stash[id] self.s0.wait_event(ev) del self.fwd_stash[id] else: break # Sync in, offload, and add an event to sync back later self.s1.wait_stream(self.s0) stream = self.s1 if self.use_streams else self.s0 with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream): cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") cpu_tensor.copy_(activation, non_blocking=True) self.tracker[tensor_id] = ( cpu_tensor, True, # True = (in future) modified ) if self.use_streams: event = self.s1.record_event() # Stash to keep activation alive til s1 is done self.fwd_stash[tensor_id] = (activation, event) else: self.tracker[tensor_id] = ( activation, False, ) # False = not modified, tensor is as is return tensor_id def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor: # backward pass - we are called with the tensor_id, which # we will use to retrieve the saved/offloaded tensor if self.is_first_backward_call: if self.is_first_forward_pass: self.is_first_forward_pass = False if self.use_pin_memory: verify_sufficient_virtual_memory() self.is_first_backward_call = False self.is_first_forward_call = True if unpack_tensor_id not in self.tracker: raise ValueError(f"Untracked tensor with id {unpack_tensor_id}") maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id] if modified: accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) maybe_accelerator_tensor = accelerator_tensor # clear tensor from tracking del self.tracker[unpack_tensor_id] return maybe_accelerator_tensor def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor: # backward pass - we are called with the tensor_id, which # we will use to retrieve the saved/offloaded tensor if self.is_first_backward_call: self.curr_graph_id = torch._C._current_graph_task_id() def wait_and_del_remaining_references() -> None: for id in list(self.bwd_tensor_stash.keys()): event = self.bwd_ev_stash[id] self.s1.wait_event(event) del self.bwd_tensor_stash[id] # Register a callback to the end of autograd to clean everything up torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references) if self.is_first_forward_pass: self.is_first_forward_pass = False if self.use_pin_memory: verify_sufficient_virtual_memory() self.is_first_backward_call = False self.is_first_forward_call = True if unpack_tensor_id not in self.tracker: raise ValueError(f"untracked tensor with id {unpack_tensor_id}") maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id] if modified: # Get data on the current autograd node graph_id = torch._C._current_graph_task_id() node = torch._C._current_autograd_node() prev_node_ids = [] # If we're on a new node, mark prev node's tensors to be freed later if graph_id == self.curr_graph_id and self.curr_autograd_node != node: self.curr_autograd_node = node prev_node_ids = list(self.bwd_tensor_stash.keys()) brought_back_from_cpu = True if unpack_tensor_id in self.fwd_stash: maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0] brought_back_from_cpu = False else: # Kick off the process to bring tensors back with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1): accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) maybe_accelerator_tensor = accelerator_tensor # Tell comp stream to wait for the info to be loaded before executing self.s0.wait_stream(self.s1) # Stash the tensor to keep memory alive until compute stream is complete self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor # Note: [Track views of the unpacked] # Why do we get the use count of the unpacked tensor here? We want an # initial count to compare to later, during the post-hook of the # backward node, when we need to decide whether we're allowed to free # the tensor yet. In what obscure cases must we delay freeing the # tensor (and thus call record_stream)? # 1. Any of the outputs of the backward node is a view of the unpacked # tensor. # 2. In the case that this unpacked tensor will be used in a # checkpointed region, if one of the recomputed saved tensors ends # up as a view of the unpacked tensor. # 3. The user abuses the system somehow and manually relies on the # unpacked tensor to exist after the backward node has executed. storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata) def hook(outputs, inputs): # create events for the current node inputs/outputs if they were streamed in if brought_back_from_cpu: # See Note: [Track views of the unpacked] # IF any of the outputs is a view of the tensor, OR if a view of # the tensor has been saved as a part of checkpoint's recompute # process, OR the user has abusedly incurred a reference on the # unpacked tensor, THEN the tensor might be used later and we # cannot presume to delete it after only the current node is # done! So we use our frenemy, record_stream, to ensure the # Tensor stays unmessed with until it's done getting used in the # compute stream (s0 here). Note that the con here is we introduce # non-deterministic (thus higher) memory usage, but this case # should not happen often. unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id] if torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) > storage_refcount: unpacked_tensor.record_stream(self.s0) del self.bwd_tensor_stash[unpack_tensor_id] else: event = self.s0.record_event() self.bwd_ev_stash[unpack_tensor_id] = event # if there are still things in the fwd_stash, get rid of them as we're in bwd now for id in list(self.fwd_stash.keys()): _, ev = self.fwd_stash[id] self.s0.wait_event(ev) del self.fwd_stash[id] # wait on prev node's events and del those for id in prev_node_ids: event = self.bwd_ev_stash[id] self.s1.wait_event(event) del self.bwd_tensor_stash[id] return outputs node.register_hook(hook) # clear tensor from tracking del self.tracker[unpack_tensor_id] return maybe_accelerator_tensor unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream super().__init__(pack_tensor, unpack_tensor) class NoOpManager(saved_tensors_hooks): """ A `saved_tensors_hook` manager used to disable any other `saved_tensors_hook` manager applied before. This relies on the behavior that only the most recently registered `saved_tensors_hook` will run. One example usage is to opt a local region of code out of activations offloading, which is usually applied globally to best track state. """ def __init__(self) -> None: def noop(tensor): return tensor super().__init__(noop, noop) def get_act_offloading_ctx_manager( model: nn.Module, use_pin_memory: bool = True, use_streams: bool = True, min_offload_size: int = 1024, max_fwd_stash_size: int = 5, warn_if_no_head: bool = True, ) -> OffloadActivations: """ Returns the activation offloading context manager for the model. All but the last output Linear in every step will be offloaded. If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is disabled, we return a NoOpManager context manager. Args: model (`nn.Module`): Model to wrap with the activation offloading context manager. use_pin_memory (`bool`, *optional*, defaults to `True`): Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly but is a limited resource. use_streams (`bool`, *optional*, defaults to `True`): Whether to use streams for performance optimization where the communications get overlapped with the computation. Requires a torch build after torch-2.5.0. min_offload_size (`int`, *optional*, defaults to `1024`): Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we do not want to waste bandwidth and resources moving it to CPU and back. max_fwd_stash_size (`int`, *optional*, defaults to `5`): Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing runtime. warn_if_no_head (`bool`, *optional*, defaults to `True`): Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output head is detected. Returns: `contextlib.ContextDecorator`: Activation offloading context manager for the model. """ activations_handling_ctx = OffloadActivations( use_pin_memory=use_pin_memory, use_streams=use_streams, min_offload_size=min_offload_size, max_fwd_stash_size=max_fwd_stash_size, ) # Below is our hack to disable offloading the last output Linear in every # step, as the cost for offloading the activation and then soon after bringing # it back is expensive. output_head_detected = False noop_ctx = NoOpManager() # Try to get the actual model if it's wrapped unwrapped_model = model if hasattr(unwrapped_model, "module"): unwrapped_model = unwrapped_model.module # check for PEFT models if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model, "peft_config"): unwrapped_model = unwrapped_model.base_model # Check for different types of output heads if hasattr(unwrapped_model, "output"): if isinstance(unwrapped_model.output, nn.Module): unwrapped_model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) unwrapped_model.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) output_head_detected = True elif hasattr(unwrapped_model.output, "linear") and isinstance(unwrapped_model.output.linear, nn.Module): unwrapped_model.output.linear.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) unwrapped_model.output.linear.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) output_head_detected = True # Check for HuggingFace model output heads elif hasattr(unwrapped_model, "lm_head"): unwrapped_model.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) unwrapped_model.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) output_head_detected = True # Check for decoder-based models elif hasattr(unwrapped_model, "decoder"): decoder = unwrapped_model.decoder if hasattr(decoder, "output"): decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) decoder.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) output_head_detected = True # Some models have lm_head in the decoder elif hasattr(decoder, "lm_head"): decoder.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) decoder.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) output_head_detected = True # Check for transformer models with final layer norm elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(unwrapped_model, "ln_f"): final_norm = getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) final_norm.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) output_head_detected = True # Check for models with head module elif hasattr(unwrapped_model, "head") and isinstance(unwrapped_model.head, nn.Module): unwrapped_model.head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) unwrapped_model.head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) output_head_detected = True if not output_head_detected and warn_if_no_head: warnings.warn( "During activation offloading, no output head was detected. If your model has an output head, it will be " "offloaded. This usually greatly slows training, given the large vocabulary size. To change this " "behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by " "passing `warn_if_no_head=False`." ) # Disable offloading for any Liger modules for name, module in unwrapped_model.named_modules(): if "liger" in name.lower(): module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) module.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) return activations_handling_ctx