"""
This file defines a mixin class for sparse transformers that enables elastic memory management.

It provides functionality to dynamically adjust memory usage by controlling gradient checkpointing
across transformer blocks, allowing computation to be traded for memory efficiency.
"""

from contextlib import contextmanager
from typing import *
import math

from ..modules import sparse as sp
from ..utils.elastic_utils import ElasticModuleMixin


class SparseTransformerElasticMixin(ElasticModuleMixin):
    """
    A mixin class for sparse transformers that provides elastic memory management capabilities.
    Extends the base ElasticModuleMixin with sparse tensor-specific functionality.
    """

    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        """
        Determines the input size from a sparse tensor.

        Args:
            x: A SparseTensor input.
            *args, **kwargs: Additional arguments (unused).

        Returns:
            The number of active elements in the sparse tensor, i.e. the first
            dimension of its feature matrix (not the channel dimension).
        """
        return x.feats.shape[0]
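
    # Note: `sp.SparseTensor.feats` is assumed here to be an [N, C] feature
    # matrix over the N active elements, so the value above is the element
    # count N (which drives activation memory), not the channel dimension C.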

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """
        Context manager that temporarily adjusts memory usage by enabling gradient
        checkpointing on a leading portion of the transformer blocks, chosen from
        the requested memory ratio.

        Args:
            mem_ratio: A value in (0, 1] indicating the desired memory ratio.
                1.0 means full memory use (no checkpointing); lower values
                checkpoint more blocks, trading recomputation in the backward
                pass for a smaller activation footprint.

        Yields:
            The exact memory ratio achievable at block granularity, which may
            differ slightly from the requested mem_ratio.
        """
        if mem_ratio == 1.0:
            yield 1.0
            return

        num_blocks = len(self.blocks)
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
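        # Checkpointing is rounded to whole blocks, so the achieved ratio can
        # differ from the request. Worked example (illustrative numbers only):
        # with num_blocks = 24 and mem_ratio = 0.5,
        # num_checkpoint_blocks = min(ceil(0.5 * 24) + 1, 24) = 13 and
        # exact_mem_ratio = 1 - 12 / 24 = 0.5. The "+ 1" biases the rounding so
        # the achieved ratio never exceeds the requested one.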
        # Enable checkpointing on the first `num_checkpoint_blocks` blocks.
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
        try:
            yield exact_mem_ratio
        finally:
            # Restore the default (no checkpointing) even if the body raises.
            for i in range(num_blocks):
                self.blocks[i].use_checkpoint = False
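

# Usage sketch (hypothetical names; assumes the mixed-in transformer's blocks
# wrap their forward pass in torch.utils.checkpoint whenever `use_checkpoint`
# is set, which is what realizes the memory/compute trade):
#
#     class SparseTransformer(SparseTransformerElasticMixin, SparseTransformerBase):
#         ...
#
#     model = SparseTransformer(...)
#     with model.with_mem_ratio(0.5) as exact_ratio:
#         out = model(x)  # roughly half the blocks recompute activations
#     # exact_ratio is the block-granular ratio actually achieved (e.g. 0.5
#     # for 24 blocks); outside the context all blocks store activations again.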