import contextlib
import functools
import inspect
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
# Since we will be patching the `scaled_dot_product_attention` function with `attention_dispatch` to take
# control for dispatching to different attention providers, we need to import the original function
# to be able to use it and not go into infinite recursion when the dispatcher calls `scaled_dot_product_attention`.
import torch.autograd
from diffusers.utils.import_utils import OptionalDependencyNotAvailable
from torch.nn.functional import scaled_dot_product_attention as native_sdpa
from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER
from finetrainers.logging import get_logger
from finetrainers.utils.import_utils import (
    is_flash_attn_available,
    is_flash_attn_version,
    is_sageattention_available,
    is_sageattention_version,
    is_torch_version,
    is_xformers_available,
    is_xformers_version,
)
if is_flash_attn_available(): | |
if is_flash_attn_version("<", "2.6.3"): | |
raise OptionalDependencyNotAvailable( | |
"The `flash-attn` library version is too old. Please update it to at least 2.6.3." | |
) | |
from flash_attn import flash_attn_func, flash_attn_varlen_func | |
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward | |
else: | |
flash_attn_func = None | |
flash_attn_varlen_func = None | |
_flash_attn_forward = None | |
_flash_attn_backward = None | |
if is_sageattention_available(): | |
if is_sageattention_version("<", "2.1.1"): | |
raise OptionalDependencyNotAvailable( | |
"The `sageattention` library version is too old. Please update it to at least 2.1.1." | |
) | |
from sageattention import ( | |
sageattn, | |
sageattn_qk_int8_pv_fp8_cuda, | |
sageattn_qk_int8_pv_fp8_cuda_sm90, | |
sageattn_qk_int8_pv_fp16_cuda, | |
sageattn_qk_int8_pv_fp16_triton, | |
sageattn_varlen, | |
) | |
else: | |
sageattn = None | |
sageattn_qk_int8_pv_fp16_cuda = None | |
sageattn_qk_int8_pv_fp16_triton = None | |
sageattn_qk_int8_pv_fp8_cuda = None | |
sageattn_qk_int8_pv_fp8_cuda_sm90 = None | |
sageattn_varlen = None | |
if is_torch_version(">=", "2.5.0"): | |
import torch.nn.attention.flex_attention as flex_attention | |
if is_torch_version(">=", "2.6.0"): | |
from torch.distributed.tensor.experimental._attention import ( | |
_AttentionOp, | |
_cp_options, | |
_templated_ring_attention, | |
_templated_ring_attention_backward, | |
set_rotate_method, | |
) | |
else:
    _cp_options = None
    _templated_ring_attention = None
    _templated_ring_attention_backward = None
    set_rotate_method = None
class _AttentionOp: | |
def __init__(self, *args, **kwargs): | |
raise OptionalDependencyNotAvailable( | |
"The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0." | |
) | |
if is_xformers_available(): | |
if is_xformers_version("<", "0.0.29"): | |
raise OptionalDependencyNotAvailable( | |
"The `xformers` library version is too old. Please update it to at least 0.0.29." | |
) | |
import xformers.ops as xops | |
else: | |
xops = None | |
logger = get_logger() | |
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] | |
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] | |
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] | |
# ===== Custom operator implementations/wrappers ===== | |
def _finetrainers_scaled_dot_product_efficient_attention_forward( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_bias: Optional[torch.Tensor] = None, | |
compute_log_sumexp: bool = False, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
# Wrapper for https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946 | |
# See: https://github.com/pytorch/pytorch/issues/152942 | |
seqlen_q = query.shape[-2] | |
out, lse, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention( | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=attn_bias, | |
compute_log_sumexp=compute_log_sumexp, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
scale=scale, | |
) | |
# LSE is aligned to the next nearest multiple of 32. This is a workaround to return the lse without alignment so that pytorch | |
# ring attention does not error out with shape mismatch | |
if compute_log_sumexp: | |
assert lse.ndim == 3 | |
lse = lse[:, :, :seqlen_q] # .contiguous() | |
return out, lse, philox_seed, philox_offset | |
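# Illustrative example of the alignment workaround above: with seqlen_q = 100, the efficient-attention
# kernel returns an LSE padded up to the next multiple of 32, i.e. shape [B, H, 128]; slicing it back to
# [:, :, :100] restores the unpadded shape that PyTorch ring attention expects.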
# aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) | |
def _finetrainers_scaled_dot_product_efficient_attention_backward( | |
grad_out_: torch.Tensor, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_bias: torch.Tensor, | |
out: torch.Tensor, | |
logsumexp: torch.Tensor, | |
philox_seed: torch.Tensor, | |
philox_offset: torch.Tensor, | |
dropout_p: float, | |
grad_input_mask: List[bool], | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
assert len(grad_input_mask) == 4 | |
# https://github.com/pytorch/pytorch/blob/bb9fbb294af385057a72e5b1386cf40f86aadbec/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h#L113 | |
kAlignLSE = 32 | |
logsumexp = torch.nn.functional.pad( | |
logsumexp, (0, kAlignLSE - (logsumexp.shape[-1] % kAlignLSE)), value=float("inf") | |
) | |
grad_query, grad_key, grad_value, grad_attn_bias = torch.ops.aten._scaled_dot_product_efficient_attention_backward( | |
grad_out_=grad_out_, | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=attn_bias, | |
out=out, | |
logsumexp=logsumexp, | |
philox_seed=philox_seed, | |
philox_offset=philox_offset, | |
dropout_p=dropout_p, | |
grad_input_mask=grad_input_mask, | |
is_causal=is_causal, | |
scale=scale, | |
) | |
return grad_query, grad_key, grad_value, grad_attn_bias | |
# This function wraps the actual _flash_attn_forward call to return LSE at index 1 to be compatible with pytorch's native ring attention | |
def _finetrainers_flash_attn_forward( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
dropout_p: float = 0.0, | |
scale: Optional[float] = None, | |
is_causal: bool = False, | |
window_size: Tuple[int, int] = (-1, -1), | |
softcap: float = 0.0, | |
alibi_slopes: Optional[torch.Tensor] = None, | |
return_softmax: bool = False, | |
): | |
query, key, value = ( | |
x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value) | |
) # [B, N, S, D] -> [B, S, N, D] | |
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( | |
query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, return_softmax | |
) | |
out = out.permute(0, 2, 1, 3).contiguous() # [B, S, N, D] -> [B, N, S, D] | |
return out, softmax_lse, q, k, v, out_padded, S_dmask, rng_state | |
# This function wraps the actual _flash_attn_backward call as the counterpart of the _finetrainers_flash_attn_forward function | |
def _finetrainers_flash_attn_backward( | |
grad_out: torch.Tensor, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
out: torch.Tensor, | |
    logsumexp: torch.Tensor,  # Needs a different name than the one used in flash-attn because _templated_ring_attention_backward assumes the name is logsumexp
dropout_p: float, | |
scale: Optional[float] = None, | |
is_causal: bool = False, | |
window_size: Tuple[int, int] = (-1, -1), | |
softcap: float = 0.0, | |
alibi_slopes: Optional[torch.Tensor] = None, | |
deterministic: bool = False, | |
rng_state: Optional[torch.Tensor] = None, | |
_permute_outputs: bool = True, | |
): | |
dq, dk, dv = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) | |
grad_out = grad_out.permute(0, 2, 1, 3).contiguous() # [B, N, S, D] -> [B, S, N, D] | |
dq, dk, dv, softmax_d = _flash_attn_backward( | |
grad_out, | |
query, | |
key, | |
value, | |
out, | |
logsumexp, | |
dq, | |
dk, | |
dv, | |
dropout_p, | |
scale, | |
is_causal, | |
window_size, | |
softcap, | |
alibi_slopes, | |
deterministic, | |
rng_state, | |
) | |
# Head dimension may have been padded | |
dq = dq[..., : grad_out.shape[-1]] | |
dk = dk[..., : grad_out.shape[-1]] | |
dv = dv[..., : grad_out.shape[-1]] | |
if _permute_outputs: | |
dq, dk, dv = (x.permute(0, 2, 1, 3).contiguous() for x in (dq, dk, dv)) # [B, S, N, D] -> [B, N, S, D] | |
return dq, dk, dv | |
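# Layout note for the two flash-attn wrappers above (illustrative shapes): the dispatcher works in the
# PyTorch SDPA layout [B, N, S, D] (batch, heads, sequence, head_dim) while flash-attn expects
# [B, S, N, D]. A query of shape [2, 24, 4096, 64] is therefore permuted to [2, 4096, 24, 64] before the
# flash-attn call, and the attention output is permuted back to [2, 24, 4096, 64].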
# ===== Attention provider ===== | |
class AttentionProvider(str, Enum): | |
# EAGER = "eager" | |
# `flash-attn` | |
FLASH = "flash" | |
FLASH_VARLEN = "flash_varlen" | |
# PyTorch native | |
FLEX = "flex" | |
NATIVE = "native" | |
_NATIVE_CUDNN = "_native_cudnn" | |
_NATIVE_EFFICIENT = "_native_efficient" | |
_NATIVE_FLASH = "_native_flash" | |
_NATIVE_MATH = "_native_math" | |
# `sageattention` | |
SAGE = "sage" | |
SAGE_VARLEN = "sage_varlen" | |
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" | |
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" | |
_SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" | |
_SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" | |
    # TODO: let's not add support for Sparge Attention for now because it requires tuning per model.
    # We can look into supporting some form of autotuning in the future.
# SPARGE = "sparge" | |
# `xformers` | |
XFORMERS = "xformers" | |
class _AttentionProviderRegistry: | |
_providers = {} | |
_constraints = {} | |
_supports_cp = {} | |
_supported_arg_names = {} | |
_active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER) | |
_checks_enabled = FINETRAINERS_ATTN_CHECKS | |
# Context parallel attributes | |
_mesh: torch.distributed.device_mesh.DeviceMesh = None | |
_convert_to_fp32: bool = None | |
_rotate_method: Literal["allgather", "alltoall"] = None | |
    @classmethod
    def register(
        cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False
    ):
logger.debug(f"Registering attention provider: {provider}") | |
def decorator(func): | |
cls._providers[provider] = func | |
cls._constraints[provider] = constraints or [] | |
cls._supports_cp[provider] = supports_cp | |
cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys()) | |
return func | |
return decorator | |
    @classmethod
    def get_active_provider(cls):
return cls._active_provider, cls._providers[cls._active_provider] | |
    @classmethod
    def list_providers(cls):
return list(cls._providers.keys()) | |
    @classmethod
    def supports_context_parallel(cls, provider: AttentionProvider):
if provider not in cls._providers: | |
raise ValueError(f"Provider {provider} is not registered.") | |
return cls._supports_cp.get(provider, False) | |
    @classmethod
    def context_parallel_enabled(cls):
return cls._mesh is not None | |
    @classmethod
    def _set_context_parallel(
        cls,
        mesh: torch.distributed.device_mesh.DeviceMesh = None,
        convert_to_fp32: bool = None,
        rotate_method: str = None,
        *,
        reset: bool = False,
    ):
if reset: | |
mesh = convert_to_fp32 = rotate_method = None | |
cls._mesh = mesh | |
cls._convert_to_fp32 = convert_to_fp32 | |
cls._rotate_method = rotate_method | |
    @classmethod
    def _raise_cp_error_if_mesh_not_set(cls):
if cls._mesh is None: | |
raise ValueError( | |
"`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods." | |
) | |
@contextlib.contextmanager
def attention_provider(
provider: AttentionProvider = AttentionProvider.NATIVE, | |
*, | |
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, | |
convert_to_fp32: bool = True, | |
rotate_method: str = "allgather", | |
): | |
"""Context manager to set the active attention provider and possibly enable context parallelism.""" | |
if provider not in _AttentionProviderRegistry._providers: | |
raise ValueError(f"Provider {provider} is not registered.") | |
if mesh is not None and not _AttentionProviderRegistry.supports_context_parallel(provider): | |
raise ValueError(f"Provider {provider} does not support context parallelism.") | |
old_provider = _AttentionProviderRegistry._active_provider | |
_AttentionProviderRegistry._active_provider = provider | |
_AttentionProviderRegistry._mesh = mesh | |
_AttentionProviderRegistry._convert_to_fp32 = convert_to_fp32 | |
_AttentionProviderRegistry._rotate_method = rotate_method | |
if mesh is not None: | |
_convert_to_f32 = _cp_options.convert_to_f32 | |
_enable_load_balance = _cp_options.enable_load_balance | |
_rotate_method = _cp_options.rotate_method | |
try: | |
yield | |
finally: | |
_AttentionProviderRegistry._active_provider = old_provider | |
_AttentionProviderRegistry._mesh = None | |
_AttentionProviderRegistry._convert_to_fp32 = None | |
_AttentionProviderRegistry._rotate_method = None | |
if mesh is not None: | |
_cp_options.convert_to_f32 = _convert_to_f32 | |
_cp_options.enable_load_balance = _enable_load_balance | |
_cp_options.rotate_method = _rotate_method | |
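# Example usage of the context manager above (illustrative sketch; `cp_mesh` is assumed to be a device
# mesh created elsewhere, e.g. via torch.distributed.device_mesh.init_device_mesh):
#
#     with attention_provider(AttentionProvider.FLASH):
#         out = attention_dispatch(query, key, value)
#
#     # With context parallelism, for a provider registered with supports_cp=True:
#     with attention_provider(AttentionProvider._NATIVE_CUDNN, mesh=cp_mesh, rotate_method="allgather"):
#         out = attention_dispatch(query, key, value)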
def attention_dispatch( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
enable_gqa: bool = False, | |
attention_kwargs: Optional[Dict[str, Any]] = None, | |
) -> torch.Tensor: | |
attention_kwargs = attention_kwargs or {} | |
provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider() | |
kwargs = { | |
"query": query, | |
"key": key, | |
"value": value, | |
"attn_mask": attn_mask, | |
"dropout_p": dropout_p, | |
"is_causal": is_causal, | |
"scale": scale, | |
"enable_gqa": enable_gqa, | |
**attention_kwargs, | |
} | |
if _AttentionProviderRegistry._checks_enabled: | |
removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name]) | |
if removed_kwargs: | |
log_freq = 512 | |
msg = ( | |
f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}. This " | |
f"message will be logged every {log_freq} calls." | |
) | |
logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", msg, log_freq) | |
for check in _AttentionProviderRegistry._constraints.get(provider_name): | |
check(**kwargs) | |
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]} | |
if _AttentionProviderRegistry.context_parallel_enabled(): | |
_set_context_parallel_options(**kwargs) | |
return provider_fn(**kwargs) | |
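# As noted in the import comment at the top of this module, a trainer can route all attention through the
# active provider by patching `scaled_dot_product_attention` with `attention_dispatch`, while the original
# kernel remains reachable via `native_sdpa`. A minimal sketch (illustrative, not executed here):
#
#     import torch.nn.functional as F
#     F.scaled_dot_product_attention = attention_dispatch
#     # ... model forward/backward now dispatches through the active AttentionProvider ...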
# ===== Helper functions ===== | |
# @torch.compiler.assume_constant_result | |
def _set_context_parallel_options(is_causal: bool, **kwargs): | |
_cp_options.enable_load_balance = is_causal | |
_cp_options.convert_to_f32 = _AttentionProviderRegistry._convert_to_fp32 | |
set_rotate_method(_AttentionProviderRegistry._rotate_method) | |
def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: | |
if attn_mask is not None: | |
raise ValueError("Attention mask must be None for this provider.") | |
def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: | |
if attn_mask is not None and is_causal: | |
raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") | |
def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: | |
if query.device != key.device or query.device != value.device: | |
raise ValueError("Query, key, and value must be on the same device.") | |
if query.dtype != key.dtype or query.dtype != value.dtype: | |
raise ValueError("Query, key, and value must have the same dtype.") | |
def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: | |
_check_device(query, key, value) | |
if query.device.type != "cuda": | |
raise ValueError("Query, key, and value must be on a CUDA device.") | |
def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: | |
def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: | |
_check_device_cuda(query, key, value) | |
if torch.cuda.get_device_capability(query.device) < (major, minor): | |
raise ValueError( | |
f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." | |
) | |
return check_device_cuda | |
def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: | |
if query.dtype != key.dtype: | |
raise ValueError("Query and key must have the same dtype.") | |
if query.dtype != value.dtype: | |
raise ValueError("Query and value must have the same dtype.") | |
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: | |
_check_qkv_dtype_match(query, key, value) | |
if query.dtype not in (torch.bfloat16, torch.float16): | |
raise ValueError("Query, key, and value must be either bfloat16 or float16.") | |
def _check_shape( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
**kwargs, | |
) -> None: | |
if query.shape[-1] != key.shape[-1]: | |
raise ValueError("Query and key must have the same last dimension.") | |
if query.shape[-2] != value.shape[-2]: | |
raise ValueError("Query and value must have the same second to last dimension.") | |
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: | |
raise ValueError("Attention mask must match the key's second to last dimension.") | |
def _prepare_for_flash_attn_or_sage_varlen( | |
batch_size: int, | |
seq_len_q: int, | |
seq_len_kv: int, | |
attn_mask: Optional[torch.Tensor] = None, | |
device: Optional[torch.device] = None, | |
) -> None: | |
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) | |
if attn_mask is None: | |
seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) | |
else: | |
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) | |
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) | |
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) | |
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) | |
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) | |
max_seqlen_q = seqlens_q.max().item() | |
max_seqlen_k = seqlens_k.max().item() | |
return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) | |
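# Worked example for the helper above (illustrative numbers): batch_size=2, seq_len_q=4, seq_len_kv=6 and
# attn_mask=None give seqlens_q=[4, 4], seqlens_k=[6, 6], cu_seqlens_q=[0, 4, 8], cu_seqlens_k=[0, 6, 12],
# max_seqlen_q=4 and max_seqlen_k=6. With a boolean key-padding mask whose rows have 6 and 3 valid keys,
# seqlens_k becomes [6, 3], cu_seqlens_k becomes [0, 6, 9] and max_seqlen_k stays 6.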
def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: | |
""" | |
Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in | |
FlashAttention/Sage varlen. | |
Supports 1D to 4D shapes and common broadcasting patterns. | |
""" | |
if attn_mask.dtype != torch.bool: | |
raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") | |
if attn_mask.ndim == 1: | |
# [seq_len_k] -> broadcast across batch | |
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) | |
elif attn_mask.ndim == 2: | |
# [batch_size, seq_len_k]. Maybe broadcast across batch | |
if attn_mask.size(0) not in [1, batch_size]: | |
raise ValueError( | |
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." | |
) | |
attn_mask = attn_mask.expand(batch_size, seq_len_k) | |
elif attn_mask.ndim == 3: | |
# [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension | |
if attn_mask.size(0) not in [1, batch_size]: | |
raise ValueError( | |
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." | |
) | |
attn_mask = attn_mask.any(dim=1) | |
attn_mask = attn_mask.expand(batch_size, seq_len_k) | |
elif attn_mask.ndim == 4: | |
# [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions | |
if attn_mask.size(0) not in [1, batch_size]: | |
raise ValueError( | |
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." | |
) | |
attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] | |
attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] | |
else: | |
raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") | |
if attn_mask.shape != (batch_size, seq_len_k): | |
raise ValueError( | |
f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" | |
) | |
return attn_mask | |
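# Illustrative example for _normalize_attn_mask: a boolean mask of shape
# [batch_size=2, num_heads=1, seq_len_q=4, seq_len_k=6] is expanded over the head/query dimensions and
# reduced with any() down to a [2, 6] key-padding mask, i.e. a key position is considered valid if any
# query in any head may attend to it, which is exactly what seqlens_k inference needs.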
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): | |
return q_idx >= kv_idx | |
# ===== Attention provider implementations ===== | |
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 | |
class _flash_attn_flash_attention(torch.autograd.Function): | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
q: torch.Tensor, | |
k: torch.Tensor, | |
v: torch.Tensor, | |
dropout_p: float = 0.0, | |
softmax_scale: Optional[float] = None, | |
causal: bool = False, | |
window_size: Tuple[int, int] = (-1, -1), | |
softcap: float = 0.0, | |
alibi_slopes: Optional[torch.Tensor] = None, | |
deterministic: bool = False, | |
return_softmax: bool = False, | |
): | |
if softmax_scale is None: | |
softmax_scale = q.shape[-1] ** (-0.5) | |
ctx.dropout_p = dropout_p | |
ctx.softmax_scale = softmax_scale | |
ctx.causal = causal | |
ctx.window_size = window_size | |
ctx.softcap = softcap | |
ctx.alibi_slopes = alibi_slopes | |
ctx.deterministic = deterministic | |
out, lse, q, k, v, out_padded, S_dmask, rng_state = _finetrainers_flash_attn_forward( | |
query=q, | |
key=k, | |
value=v, | |
dropout_p=dropout_p, | |
scale=softmax_scale, | |
is_causal=causal, | |
window_size=window_size, | |
softcap=softcap, | |
alibi_slopes=alibi_slopes, | |
return_softmax=return_softmax, | |
) | |
ctx.save_for_backward(q, k, v, out_padded, lse, rng_state) | |
return (out, lse) if return_softmax else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
q, k, v, out, lse, rng_state = ctx.saved_tensors | |
grad_query, grad_key, grad_value = _finetrainers_flash_attn_backward( | |
grad_out=grad_out, | |
query=q, | |
key=k, | |
value=v, | |
out=out, | |
logsumexp=lse, | |
dropout_p=ctx.dropout_p, | |
scale=ctx.softmax_scale, | |
is_causal=ctx.causal, | |
window_size=ctx.window_size, | |
softcap=ctx.softcap, | |
alibi_slopes=ctx.alibi_slopes, | |
deterministic=ctx.deterministic, | |
rng_state=rng_state, | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None | |
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 | |
class _native_ring_flash_attn_flash_attention(torch.autograd.Function): | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
q: torch.Tensor, | |
k: torch.Tensor, | |
v: torch.Tensor, | |
dropout_p: float = 0.0, | |
softmax_scale: Optional[float] = None, | |
causal: bool = False, | |
window_size: Tuple[int, int] = (-1, -1), | |
softcap: float = 0.0, | |
alibi_slopes: Optional[torch.Tensor] = None, | |
deterministic: bool = False, | |
return_softmax: bool = False, | |
): | |
if softmax_scale is None: | |
softmax_scale = q.shape[-1] ** (-0.5) | |
# For ring flash attention using the flash-attn repo, we want the LSE but flash-attn only supports it if dropout_p > 0 | |
dropout_p = dropout_p if dropout_p > 0 else 1e-30 | |
ctx.dropout_p = dropout_p | |
ctx.softmax_scale = softmax_scale | |
ctx.causal = causal | |
ctx.window_size = window_size | |
ctx.softcap = softcap | |
ctx.alibi_slopes = alibi_slopes | |
ctx.deterministic = deterministic | |
out, lse, q, k, v, out_padded, S_dmask, rng_state = _templated_ring_attention( | |
mesh=_AttentionProviderRegistry._mesh, | |
seq_dim=2, | |
op=_finetrainers_flash_attn_forward, | |
query=q, | |
key=k, | |
value=v, | |
dropout_p=dropout_p, | |
scale=softmax_scale, | |
is_causal=causal, | |
window_size=window_size, | |
softcap=softcap, | |
alibi_slopes=alibi_slopes, | |
return_softmax=True, | |
) | |
ctx.save_for_backward(q, k, v, out_padded, lse, rng_state) | |
return (out, lse) if return_softmax else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
q, k, v, out, lse, rng_state = ctx.saved_tensors | |
lse = lse.permute(0, 2, 1).contiguous() # [B, N, S] -> [B, S, N] | |
grad_query, grad_key, grad_value = _templated_ring_attention_backward( | |
mesh=_AttentionProviderRegistry._mesh, | |
# This needs to be 1 because q, k, v, out_padded returned from forward are BSND instead of BNSD | |
# The grad_out permutation is handled in _finetrainers_flash_attn_backward, and the outputs from that are expected to have | |
# shape BSND instead of BNSD (requirement of _templated_ring_attention_backward), so we need to set seq_dim=1 and permute the | |
# returned outputs | |
seq_dim=1, | |
op=functools.partial(_finetrainers_flash_attn_backward, _permute_outputs=False), | |
grad_out=grad_out, | |
grad_out_name="grad_out", | |
query=q, | |
key=k, | |
value=v, | |
out=out, | |
logsumexp=lse, | |
dropout_p=ctx.dropout_p, | |
scale=ctx.softmax_scale, | |
is_causal=ctx.causal, | |
window_size=ctx.window_size, | |
softcap=ctx.softcap, | |
alibi_slopes=ctx.alibi_slopes, | |
deterministic=ctx.deterministic, | |
rng_state=rng_state, | |
) | |
grad_query, grad_key, grad_value = ( | |
x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value) | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None | |
def flash_attn_flash_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
dropout_p: float = 0.0, | |
scale: Optional[float] = None, | |
is_causal: bool = False, | |
window_size: Tuple[int, int] = (-1, -1), | |
softcap: float = 0.0, | |
alibi_slopes: Optional[torch.Tensor] = None, | |
deterministic: bool = False, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
dispatch_fn = ( | |
_native_ring_flash_attn_flash_attention | |
if _AttentionProviderRegistry.context_parallel_enabled() | |
else _flash_attn_flash_attention | |
) | |
return dispatch_fn.apply( | |
query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, return_lse | |
) | |
def _flash_varlen_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
cu_seqlens_q: Optional[torch.Tensor] = None, | |
cu_seqlens_k: Optional[torch.Tensor] = None, | |
max_seqlen_q: Optional[int] = None, | |
max_seqlen_k: Optional[int] = None, | |
dropout_p: float = 0.0, | |
scale: Optional[float] = None, | |
is_causal: bool = False, | |
window_size: Tuple[int, int] = (-1, -1), | |
softcap: float = 0.0, | |
alibi_slopes: Optional[torch.Tensor] = None, | |
deterministic: bool = False, | |
return_attn_probs: bool = False, | |
attn_mask: Optional[torch.Tensor] = None, | |
) -> torch.Tensor: | |
batch_size, _, seq_len_q, _ = query.shape | |
_, _, seq_len_kv, _ = key.shape | |
if attn_mask is not None: | |
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | |
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): | |
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | |
_prepare_for_flash_attn_or_sage_varlen( | |
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device | |
) | |
) | |
else: | |
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) | |
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) | |
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) | |
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) | |
key_valid, value_valid = [], [] | |
for b in range(batch_size): | |
valid_len = seqlens_k[b] | |
key_valid.append(key[b, :valid_len]) | |
value_valid.append(value[b, :valid_len]) | |
query_packed = query.flatten(0, 1) | |
key_packed = torch.cat(key_valid, dim=0) | |
value_packed = torch.cat(value_valid, dim=0) | |
if _AttentionProviderRegistry.context_parallel_enabled(): | |
return_attn_probs = True | |
out = flash_attn_varlen_func( | |
q=query_packed, | |
k=key_packed, | |
v=value_packed, | |
cu_seqlens_q=cu_seqlens_q, | |
cu_seqlens_k=cu_seqlens_k, | |
max_seqlen_q=max_seqlen_q, | |
max_seqlen_k=max_seqlen_k, | |
dropout_p=dropout_p, | |
softmax_scale=scale, | |
causal=is_causal, | |
window_size=window_size, | |
softcap=softcap, | |
alibi_slopes=alibi_slopes, | |
deterministic=deterministic, | |
return_attn_probs=return_attn_probs, | |
) | |
rest = None | |
if return_attn_probs: | |
out, *rest = out | |
out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous() | |
if return_attn_probs: | |
return out, *rest[:1] | |
return out | |
def _native_flex_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
enable_gqa: bool = False, | |
return_lse: bool = False, | |
kernel_options: Optional[Dict[str, Any]] = None, | |
) -> torch.Tensor: | |
# TODO: should we LRU cache the block mask creation? | |
score_mod = None | |
block_mask = None | |
batch_size, num_heads, seq_len_q, _ = query.shape | |
_, _, seq_len_kv, _ = key.shape | |
if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): | |
block_mask = attn_mask | |
elif is_causal: | |
block_mask = flex_attention.create_block_mask( | |
_flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device | |
) | |
elif torch.is_tensor(attn_mask): | |
if attn_mask.ndim == 2: | |
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) | |
attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) | |
if attn_mask.dtype == torch.bool: | |
# TODO: this probably does not work but verify! | |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx): | |
return attn_mask[batch_idx, head_idx, q_idx, kv_idx] | |
block_mask = flex_attention.create_block_mask( | |
mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device | |
) | |
else: | |
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): | |
return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] | |
else: | |
raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") | |
return flex_attention.flex_attention( | |
query=query, | |
key=key, | |
value=value, | |
score_mod=score_mod, | |
block_mask=block_mask, | |
scale=scale, | |
enable_gqa=enable_gqa, | |
return_lse=return_lse, | |
        kernel_options=kernel_options,
) | |
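# Example of calling the flex provider with a prebuilt causal BlockMask (illustrative sketch, assuming
# PyTorch >= 2.5 so that flex_attention is importable):
#
#     block_mask = flex_attention.create_block_mask(
#         _flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, device=query.device
#     )
#     out = _native_flex_attention(query, key, value, attn_mask=block_mask)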
def _native_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
enable_gqa: bool = False, | |
) -> torch.Tensor: | |
return native_sdpa( | |
query=query, | |
key=key, | |
value=value, | |
attn_mask=attn_mask, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
scale=scale, | |
enable_gqa=enable_gqa, | |
) | |
class _native_cudnn_attention(torch.autograd.Function): | |
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 | |
# forward declaration: | |
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) | |
# backward declaration: | |
# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
): | |
ctx.dropout_p = dropout_p | |
ctx.is_causal = is_causal | |
ctx.scale = scale | |
ctx.attn_mask = attn_mask | |
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( | |
torch.ops.aten._scaled_dot_product_cudnn_attention( | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=attn_mask, | |
compute_log_sumexp=True, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
return_debug_mask=False, | |
scale=scale, | |
) | |
) | |
ctx.max_q = max_q | |
ctx.max_k = max_k | |
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) | |
return (out, lse) if return_lse else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors | |
grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward( | |
grad_out=grad_out, | |
query=query, | |
key=key, | |
value=value, | |
out=out, | |
logsumexp=lse, | |
philox_seed=philox_seed, | |
philox_offset=philox_offset, | |
attn_bias=ctx.attn_mask, | |
cum_seq_q=cum_seq_q, | |
cum_seq_k=cum_seq_k, | |
max_q=ctx.max_q, | |
max_k=ctx.max_k, | |
dropout_p=ctx.dropout_p, | |
is_causal=ctx.is_causal, | |
scale=ctx.scale, | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None, None | |
class _native_ring_native_cudnn_attention(torch.autograd.Function): | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
): | |
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() | |
ctx.dropout_p = dropout_p | |
ctx.is_causal = is_causal | |
ctx.scale = scale | |
ctx.attn_mask = attn_mask | |
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( | |
_templated_ring_attention( | |
mesh=_AttentionProviderRegistry._mesh, | |
seq_dim=2, | |
op=torch.ops.aten._scaled_dot_product_cudnn_attention, | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=attn_mask, | |
compute_log_sumexp=True, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
return_debug_mask=False, | |
scale=scale, | |
) | |
) | |
ctx.max_q = max_q | |
ctx.max_k = max_k | |
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) | |
return (out, lse) if return_lse else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() | |
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors | |
grad_query, grad_key, grad_value = _templated_ring_attention_backward( | |
mesh=_AttentionProviderRegistry._mesh, | |
seq_dim=2, | |
op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward, | |
grad_out=grad_out, | |
grad_out_name="grad_out", | |
query=query, | |
key=key, | |
value=value, | |
out=out, | |
logsumexp=lse, | |
philox_seed=philox_seed, | |
philox_offset=philox_offset, | |
attn_bias=ctx.attn_mask, | |
cum_seq_q=cum_seq_q, | |
cum_seq_k=cum_seq_k, | |
max_q=ctx.max_q, | |
max_k=ctx.max_k, | |
dropout_p=ctx.dropout_p, | |
is_causal=ctx.is_causal, | |
scale=ctx.scale, | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None, None | |
def native_cudnn_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
dispatch_fn = ( | |
_native_ring_native_cudnn_attention | |
if _AttentionProviderRegistry.context_parallel_enabled() | |
else _native_cudnn_attention | |
) | |
return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse) | |
class _native_efficient_attention(torch.autograd.Function): | |
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946 | |
# forward declaration: | |
# aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) | |
# backward declaration: | |
# aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
): | |
ctx.dropout_p = dropout_p | |
ctx.is_causal = is_causal | |
ctx.scale = scale | |
ctx.attn_mask = attn_mask | |
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. | |
out, lse, philox_seed, philox_offset = _finetrainers_scaled_dot_product_efficient_attention_forward( | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=attn_mask, | |
compute_log_sumexp=True, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
scale=scale, | |
) | |
ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset) | |
return (out, lse) if return_lse else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors | |
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. | |
grad_query, grad_key, grad_value, grad_attn_bias = ( | |
_finetrainers_scaled_dot_product_efficient_attention_backward( | |
grad_out_=grad_out, | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=ctx.attn_mask, | |
out=out, | |
logsumexp=lse, | |
philox_seed=philox_seed, | |
philox_offset=philox_offset, | |
dropout_p=ctx.dropout_p, | |
grad_input_mask=[True, True, True, False], | |
is_causal=ctx.is_causal, | |
scale=ctx.scale, | |
) | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None, None | |
class _native_ring_native_efficient_attention(torch.autograd.Function): | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
): | |
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() | |
ctx.dropout_p = dropout_p | |
ctx.is_causal = is_causal | |
ctx.scale = scale | |
ctx.attn_mask = attn_mask | |
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. | |
out, lse, philox_seed, philox_offset = _templated_ring_attention( | |
mesh=_AttentionProviderRegistry._mesh, | |
seq_dim=2, | |
op=_finetrainers_scaled_dot_product_efficient_attention_forward, | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=attn_mask, | |
compute_log_sumexp=True, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
scale=scale, | |
) | |
ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset) | |
return (out, lse) if return_lse else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() | |
query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors | |
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. | |
grad_query, grad_key, grad_value, grad_attn_bias = _templated_ring_attention_backward( | |
mesh=_AttentionProviderRegistry._mesh, | |
seq_dim=2, | |
op=_finetrainers_scaled_dot_product_efficient_attention_backward, | |
grad_out=grad_out, | |
grad_out_name="grad_out_", | |
query=query, | |
key=key, | |
value=value, | |
attn_bias=ctx.attn_mask, | |
out=out, | |
logsumexp=lse, | |
philox_seed=philox_seed, | |
philox_offset=philox_offset, | |
dropout_p=ctx.dropout_p, | |
grad_input_mask=[True, True, True, False], | |
is_causal=ctx.is_causal, | |
scale=ctx.scale, | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None, None | |
def native_efficient_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
) -> torch.Tensor: | |
dispatch_fn = ( | |
_native_ring_native_efficient_attention | |
if _AttentionProviderRegistry.context_parallel_enabled() | |
else _native_efficient_attention | |
) | |
return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale) | |
class _native_flash_attention(torch.autograd.Function): | |
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14910 | |
# forward declaration: | |
# aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) | |
# backward declaration: | |
# aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
): | |
ctx.dropout_p = dropout_p | |
ctx.is_causal = is_causal | |
ctx.scale = scale | |
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( | |
torch.ops.aten._scaled_dot_product_flash_attention( | |
query=query, | |
key=key, | |
value=value, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
return_debug_mask=False, | |
scale=scale, | |
) | |
) | |
ctx.max_q = max_q | |
ctx.max_k = max_k | |
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) | |
return (out, lse) if return_lse else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors | |
grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward( | |
grad_out=grad_out, | |
query=query, | |
key=key, | |
value=value, | |
out=out, | |
logsumexp=lse, | |
cum_seq_q=cum_seq_q, | |
cum_seq_k=cum_seq_k, | |
max_q=ctx.max_q, | |
max_k=ctx.max_k, | |
dropout_p=ctx.dropout_p, | |
is_causal=ctx.is_causal, | |
philox_seed=philox_seed, | |
philox_offset=philox_offset, | |
scale=ctx.scale, | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None | |
class _native_ring_native_flash_attention(torch.autograd.Function): | |
def forward( | |
ctx: torch.autograd.function.FunctionCtx, | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
): | |
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() | |
ctx.dropout_p = dropout_p | |
ctx.is_causal = is_causal | |
ctx.scale = scale | |
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( | |
_templated_ring_attention( | |
mesh=_AttentionProviderRegistry._mesh, | |
seq_dim=2, | |
op=torch.ops.aten._scaled_dot_product_flash_attention, | |
query=query, | |
key=key, | |
value=value, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
scale=scale, | |
) | |
) | |
ctx.max_q = max_q | |
ctx.max_k = max_k | |
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) | |
return (out, lse) if return_lse else out | |
def backward( | |
ctx: torch.autograd.function.FunctionCtx, | |
grad_out: torch.Tensor, | |
*args: torch.Tensor, | |
): | |
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() | |
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors | |
grad_query, grad_key, grad_value, *_ = _templated_ring_attention_backward( | |
mesh=_AttentionProviderRegistry._mesh, | |
seq_dim=2, | |
op=torch.ops.aten._scaled_dot_product_flash_attention_backward, | |
grad_out=grad_out, | |
grad_out_name="grad_out", | |
query=query, | |
key=key, | |
value=value, | |
out=out, | |
logsumexp=lse, | |
dropout_p=ctx.dropout_p, | |
is_causal=ctx.is_causal, | |
scale=ctx.scale, | |
cum_seq_q=cum_seq_q, | |
cum_seq_k=cum_seq_k, | |
max_q=ctx.max_q, | |
max_k=ctx.max_k, | |
philox_seed=philox_seed, | |
philox_offset=philox_offset, | |
) | |
return grad_query, grad_key, grad_value, None, None, None, None | |
def native_flash_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
dispatch_fn = ( | |
_native_ring_native_flash_attention | |
if _AttentionProviderRegistry.context_parallel_enabled() | |
else _native_flash_attention | |
) | |
return dispatch_fn.apply(query, key, value, dropout_p, is_causal, scale, return_lse) | |
# class _native_math_attention(torch.autograd.Function): | |
# # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14901 | |
# # forward declaration: | |
# # aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0., bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) | |
# # backward declaration: | |
# # does not exist | |
# @staticmethod | |
# def forward( | |
# ctx: torch.autograd.function.FunctionCtx, | |
# query: torch.Tensor, | |
# key: torch.Tensor, | |
# value: torch.Tensor, | |
# attn_mask: Optional[torch.Tensor] = None, | |
# dropout_p: float = 0.0, | |
# is_causal: bool = False, | |
# dropout_mask: Optional[torch.Tensor] = None, | |
# scale: Optional[float] = None, | |
# enable_gqa: bool = False, | |
# return_scores: bool = False, | |
# ): | |
# ctx.dropout_p = dropout_p | |
# ctx.is_causal = is_causal | |
# ctx.scale = scale | |
# ctx.enable_gqa = enable_gqa | |
# print(f"query.shape: {query.shape}") | |
# with torch.enable_grad(): | |
# out, scores = torch.ops.aten._scaled_dot_product_attention_math( | |
# query=query, | |
# key=key, | |
# value=value, | |
# attn_mask=attn_mask, | |
# dropout_p=dropout_p, | |
# is_causal=is_causal, | |
# dropout_mask=dropout_mask, | |
# scale=scale, | |
# enable_gqa=enable_gqa, | |
# ) | |
# ctx.save_for_backward(query, key, value, out) | |
# return (out, scores) if return_scores else out | |
# @staticmethod | |
# def backward( | |
# ctx: torch.autograd.function.FunctionCtx, | |
# grad_out: torch.Tensor, | |
# ): | |
# raise NotImplementedError("Backward pass for native math attention is not implemented.") | |
def native_math_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
enable_gqa: bool = False, | |
) -> torch.Tensor: | |
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): | |
return native_sdpa( | |
query=query, | |
key=key, | |
value=value, | |
attn_mask=attn_mask, | |
dropout_p=dropout_p, | |
is_causal=is_causal, | |
scale=scale, | |
enable_gqa=enable_gqa, | |
) | |
def _sage_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
if _AttentionProviderRegistry.context_parallel_enabled(): | |
return_lse = True | |
kwargs = { | |
"q": query, | |
"k": key, | |
"v": value, | |
"tensor_layout": "HND", | |
"is_causal": is_causal, | |
"sm_scale": scale, | |
"return_lse": return_lse, | |
} | |
out = sageattn(**kwargs) | |
rest = None | |
if return_lse: | |
out, *rest = out | |
if return_lse: | |
return out, *rest[:1] | |
return out | |
def _sage_varlen_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
cu_seqlens_q: Optional[torch.Tensor] = None, | |
cu_seqlens_k: Optional[torch.Tensor] = None, | |
max_seqlen_q: Optional[int] = None, | |
max_seqlen_k: Optional[int] = None, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
smooth_k: bool = True, | |
attn_mask: Optional[torch.Tensor] = None, | |
enable_gqa: bool = False, | |
) -> torch.Tensor: | |
batch_size, _, seq_len_q, _ = query.shape | |
_, _, seq_len_kv, _ = key.shape | |
if attn_mask is not None: | |
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | |
if enable_gqa: | |
# TODO | |
pass | |
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): | |
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | |
_prepare_for_flash_attn_or_sage_varlen( | |
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device | |
) | |
) | |
else: | |
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) | |
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) | |
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) | |
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) | |
key_valid, value_valid = [], [] | |
for b in range(batch_size): | |
valid_len = seqlens_k[b] | |
key_valid.append(key[b, :valid_len]) | |
value_valid.append(value[b, :valid_len]) | |
query_packed = query.flatten(0, 1) | |
key_packed = torch.cat(key_valid, dim=0) | |
value_packed = torch.cat(value_valid, dim=0) | |
out = sageattn_varlen( | |
q=query_packed, | |
k=key_packed, | |
v=value_packed, | |
cu_seqlens_q=cu_seqlens_q, | |
cu_seqlens_k=cu_seqlens_k, | |
max_seqlen_q=max_seqlen_q, | |
max_seqlen_k=max_seqlen_k, | |
is_causal=is_causal, | |
sm_scale=scale, | |
smooth_k=smooth_k, | |
) | |
out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous() | |
return out | |
def _sage_qk_int8_pv_fp8_cuda_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", | |
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", | |
smooth_k: bool = True, | |
smooth_v: bool = False, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
return sageattn_qk_int8_pv_fp8_cuda( | |
q=query, | |
k=key, | |
v=value, | |
tensor_layout="HND", | |
is_causal=is_causal, | |
qk_quant_gran=qk_quant_gran, | |
sm_scale=scale, | |
pv_accum_dtype=pv_accum_dtype, | |
smooth_k=smooth_k, | |
smooth_v=smooth_v, | |
return_lse=return_lse, | |
) | |
def _sage_qk_int8_pv_fp8_cuda_sm90_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", | |
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", | |
smooth_k: bool = True, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
return sageattn_qk_int8_pv_fp8_cuda_sm90( | |
q=query, | |
k=key, | |
v=value, | |
tensor_layout="HND", | |
is_causal=is_causal, | |
qk_quant_gran=qk_quant_gran, | |
sm_scale=scale, | |
pv_accum_dtype=pv_accum_dtype, | |
smooth_k=smooth_k, | |
return_lse=return_lse, | |
) | |
def _sage_qk_int8_pv_fp16_cuda_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", | |
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", | |
smooth_k: bool = True, | |
smooth_v: bool = False, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
return sageattn_qk_int8_pv_fp16_cuda( | |
q=query, | |
k=key, | |
v=value, | |
tensor_layout="HND", | |
is_causal=is_causal, | |
qk_quant_gran=qk_quant_gran, | |
sm_scale=scale, | |
pv_accum_dtype=pv_accum_dtype, | |
smooth_k=smooth_k, | |
smooth_v=smooth_v, | |
return_lse=return_lse, | |
) | |
def _sage_qk_int8_pv_fp16_triton_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", | |
smooth_k: bool = True, | |
return_lse: bool = False, | |
) -> torch.Tensor: | |
return sageattn_qk_int8_pv_fp16_triton( | |
q=query, | |
k=key, | |
v=value, | |
tensor_layout="HND", | |
quantization_backend=quantization_backend, | |
is_causal=is_causal, | |
sm_scale=scale, | |
smooth_k=smooth_k, | |
return_lse=return_lse, | |
) | |
def _xformers_attention( | |
query: torch.Tensor, | |
key: torch.Tensor, | |
value: torch.Tensor, | |
attn_mask: Optional[torch.Tensor] = None, | |
dropout_p: float = 0.0, | |
is_causal: bool = False, | |
scale: Optional[float] = None, | |
enable_gqa: bool = False, | |
) -> torch.Tensor: | |
batch_size, num_heads_q, seq_len_q, _ = query.shape | |
_, num_heads_kv, seq_len_kv, _ = key.shape | |
# TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns | |
if is_causal: | |
attn_mask = xops.LowerTriangularMask() | |
elif attn_mask is not None: | |
if attn_mask.ndim == 2: | |
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) | |
elif attn_mask.ndim != 4: | |
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") | |
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) | |
# QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers | |
# query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) | |
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) | |
if enable_gqa: | |
if num_heads_q % num_heads_kv != 0: | |
raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") | |
num_heads_per_group = num_heads_q // num_heads_kv | |
query = query.unflatten(2, (num_heads_kv, -1)) | |
key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) | |
value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) | |
out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) | |
if enable_gqa: | |
out = out.flatten(2, 3) | |
out = out.permute(0, 2, 1, 3) # .contiguous() | |
return out | |
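# Illustrative shape walk-through for the GQA path above: with num_heads_q=8, num_heads_kv=2 and inputs
# already permuted to [B, S, H, D], query is unflattened to [B, S, 2, 4, D] while key/value are
# unflattened to [B, S, 2, 1, D] and expanded to [B, S, 2, 4, D]; xformers then treats (2, 4) as grouped
# heads, and the output is flattened back to [B, S, 8, D] before the final permute to [B, H, S, D].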