import contextlib import functools import inspect from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch # Since we will be patching the `scaled_dot_product_attention` function with `attention_dispatch` to take # control for dispatching to different attention providers, we need to import the original function # to be able to use it and not go into infinite recursion when the dispatcher calls `scaled_dot_product_attention`. import torch.autograd from diffusers.utils.import_utils import OptionalDependencyNotAvailable from torch.nn.functional import scaled_dot_product_attention as native_sdpa from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER from finetrainers.logging import get_logger from finetrainers.utils.import_utils import ( is_flash_attn_available, is_flash_attn_version, is_sageattention_available, is_sageattention_version, is_torch_version, is_xformers_available, is_xformers_version, ) if is_flash_attn_available(): if is_flash_attn_version("<", "2.6.3"): raise OptionalDependencyNotAvailable( "The `flash-attn` library version is too old. Please update it to at least 2.6.3." ) from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward else: flash_attn_func = None flash_attn_varlen_func = None _flash_attn_forward = None _flash_attn_backward = None if is_sageattention_available(): if is_sageattention_version("<", "2.1.1"): raise OptionalDependencyNotAvailable( "The `sageattention` library version is too old. Please update it to at least 2.1.1." ) from sageattention import ( sageattn, sageattn_qk_int8_pv_fp8_cuda, sageattn_qk_int8_pv_fp8_cuda_sm90, sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_varlen, ) else: sageattn = None sageattn_qk_int8_pv_fp16_cuda = None sageattn_qk_int8_pv_fp16_triton = None sageattn_qk_int8_pv_fp8_cuda = None sageattn_qk_int8_pv_fp8_cuda_sm90 = None sageattn_varlen = None if is_torch_version(">=", "2.5.0"): import torch.nn.attention.flex_attention as flex_attention if is_torch_version(">=", "2.6.0"): from torch.distributed.tensor.experimental._attention import ( _AttentionOp, _cp_options, _templated_ring_attention, _templated_ring_attention_backward, set_rotate_method, ) else: _cp_options = None _templated_ring_attention = None set_rotate_method = None class _AttentionOp: def __init__(self, *args, **kwargs): raise OptionalDependencyNotAvailable( "The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0." ) if is_xformers_available(): if is_xformers_version("<", "0.0.29"): raise OptionalDependencyNotAvailable( "The `xformers` library version is too old. Please update it to at least 0.0.29." 
) import xformers.ops as xops else: xops = None logger = get_logger() _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] # ===== Custom operator implementations/wrappers ===== def _finetrainers_scaled_dot_product_efficient_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_bias: Optional[torch.Tensor] = None, compute_log_sumexp: bool = False, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # Wrapper for https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946 # See: https://github.com/pytorch/pytorch/issues/152942 seqlen_q = query.shape[-2] out, lse, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention( query=query, key=key, value=value, attn_bias=attn_bias, compute_log_sumexp=compute_log_sumexp, dropout_p=dropout_p, is_causal=is_causal, scale=scale, ) # LSE is aligned to the next nearest multiple of 32. This is a workaround to return the lse without alignment so that pytorch # ring attention does not error out with shape mismatch if compute_log_sumexp: assert lse.ndim == 3 lse = lse[:, :, :seqlen_q] # .contiguous() return out, lse, philox_seed, philox_offset # aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) def _finetrainers_scaled_dot_product_efficient_attention_backward( grad_out_: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_bias: torch.Tensor, out: torch.Tensor, logsumexp: torch.Tensor, philox_seed: torch.Tensor, philox_offset: torch.Tensor, dropout_p: float, grad_input_mask: List[bool], is_causal: bool = False, scale: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: assert len(grad_input_mask) == 4 # https://github.com/pytorch/pytorch/blob/bb9fbb294af385057a72e5b1386cf40f86aadbec/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h#L113 kAlignLSE = 32 logsumexp = torch.nn.functional.pad( logsumexp, (0, kAlignLSE - (logsumexp.shape[-1] % kAlignLSE)), value=float("inf") ) grad_query, grad_key, grad_value, grad_attn_bias = torch.ops.aten._scaled_dot_product_efficient_attention_backward( grad_out_=grad_out_, query=query, key=key, value=value, attn_bias=attn_bias, out=out, logsumexp=logsumexp, philox_seed=philox_seed, philox_offset=philox_offset, dropout_p=dropout_p, grad_input_mask=grad_input_mask, is_causal=is_causal, scale=scale, ) return grad_query, grad_key, grad_value, grad_attn_bias # This function wraps the actual _flash_attn_forward call to return LSE at index 1 to be compatible with pytorch's native ring attention def _finetrainers_flash_attn_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dropout_p: float = 0.0, scale: Optional[float] = None, is_causal: bool = False, window_size: Tuple[int, int] = (-1, -1), softcap: float = 0.0, alibi_slopes: Optional[torch.Tensor] = None, return_softmax: bool = False, ): query, key, value = ( x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value) ) # [B, N, S, D] -> [B, 
S, N, D] out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, return_softmax ) out = out.permute(0, 2, 1, 3).contiguous() # [B, S, N, D] -> [B, N, S, D] return out, softmax_lse, q, k, v, out_padded, S_dmask, rng_state # This function wraps the actual _flash_attn_backward call as the counterpart of the _finetrainers_flash_attn_forward function def _finetrainers_flash_attn_backward( grad_out: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, out: torch.Tensor, logsumexp: torch.Tensor, # Needs a different names than the one used in flash-attn because _templated_ring_attention_backward assumes name is logsumexp dropout_p: float, scale: Optional[float] = None, is_causal: bool = False, window_size: Tuple[int, int] = (-1, -1), softcap: float = 0.0, alibi_slopes: Optional[torch.Tensor] = None, deterministic: bool = False, rng_state: Optional[torch.Tensor] = None, _permute_outputs: bool = True, ): dq, dk, dv = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) grad_out = grad_out.permute(0, 2, 1, 3).contiguous() # [B, N, S, D] -> [B, S, N, D] dq, dk, dv, softmax_d = _flash_attn_backward( grad_out, query, key, value, out, logsumexp, dq, dk, dv, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, rng_state, ) # Head dimension may have been padded dq = dq[..., : grad_out.shape[-1]] dk = dk[..., : grad_out.shape[-1]] dv = dv[..., : grad_out.shape[-1]] if _permute_outputs: dq, dk, dv = (x.permute(0, 2, 1, 3).contiguous() for x in (dq, dk, dv)) # [B, S, N, D] -> [B, N, S, D] return dq, dk, dv # ===== Attention provider ===== class AttentionProvider(str, Enum): # EAGER = "eager" # `flash-attn` FLASH = "flash" FLASH_VARLEN = "flash_varlen" # PyTorch native FLEX = "flex" NATIVE = "native" _NATIVE_CUDNN = "_native_cudnn" _NATIVE_EFFICIENT = "_native_efficient" _NATIVE_FLASH = "_native_flash" _NATIVE_MATH = "_native_math" # `sageattention` SAGE = "sage" SAGE_VARLEN = "sage_varlen" _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" # TODO: let's not add support for Sparge Attention now because it requires tuning per model # We can look into supporting something "autotune"-ing in the future # SPARGE = "sparge" # `xformers` XFORMERS = "xformers" class _AttentionProviderRegistry: _providers = {} _constraints = {} _supports_cp = {} _supported_arg_names = {} _active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER) _checks_enabled = FINETRAINERS_ATTN_CHECKS # Context parallel attributes _mesh: torch.distributed.device_mesh.DeviceMesh = None _convert_to_fp32: bool = None _rotate_method: Literal["allgather", "alltoall"] = None @classmethod def register( cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False ): logger.debug(f"Registering attention provider: {provider}") def decorator(func): cls._providers[provider] = func cls._constraints[provider] = constraints or [] cls._supports_cp[provider] = supports_cp cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys()) return func return decorator @classmethod def get_active_provider(cls): return cls._active_provider, cls._providers[cls._active_provider] @classmethod def list_providers(cls): return 
list(cls._providers.keys()) @classmethod def supports_context_parallel(cls, provider: AttentionProvider): if provider not in cls._providers: raise ValueError(f"Provider {provider} is not registered.") return cls._supports_cp.get(provider, False) @classmethod def context_parallel_enabled(cls): return cls._mesh is not None @classmethod def _set_context_parallel( cls, mesh: torch.distributed.device_mesh.DeviceMesh = None, convert_to_fp32: bool = None, rotate_method: str = None, *, reset: bool = False, ): if reset: mesh = convert_to_fp32 = rotate_method = None cls._mesh = mesh cls._convert_to_fp32 = convert_to_fp32 cls._rotate_method = rotate_method @classmethod def _raise_cp_error_if_mesh_not_set(cls): if cls._mesh is None: raise ValueError( "`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods." ) @contextlib.contextmanager def attention_provider( provider: AttentionProvider = AttentionProvider.NATIVE, *, mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, convert_to_fp32: bool = True, rotate_method: str = "allgather", ): """Context manager to set the active attention provider and possibly enable context parallelism.""" if provider not in _AttentionProviderRegistry._providers: raise ValueError(f"Provider {provider} is not registered.") if mesh is not None and not _AttentionProviderRegistry.supports_context_parallel(provider): raise ValueError(f"Provider {provider} does not support context parallelism.") old_provider = _AttentionProviderRegistry._active_provider _AttentionProviderRegistry._active_provider = provider _AttentionProviderRegistry._mesh = mesh _AttentionProviderRegistry._convert_to_fp32 = convert_to_fp32 _AttentionProviderRegistry._rotate_method = rotate_method if mesh is not None: _convert_to_f32 = _cp_options.convert_to_f32 _enable_load_balance = _cp_options.enable_load_balance _rotate_method = _cp_options.rotate_method try: yield finally: _AttentionProviderRegistry._active_provider = old_provider _AttentionProviderRegistry._mesh = None _AttentionProviderRegistry._convert_to_fp32 = None _AttentionProviderRegistry._rotate_method = None if mesh is not None: _cp_options.convert_to_f32 = _convert_to_f32 _cp_options.enable_load_balance = _enable_load_balance _cp_options.rotate_method = _rotate_method def attention_dispatch( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: attention_kwargs = attention_kwargs or {} provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider() kwargs = { "query": query, "key": key, "value": value, "attn_mask": attn_mask, "dropout_p": dropout_p, "is_causal": is_causal, "scale": scale, "enable_gqa": enable_gqa, **attention_kwargs, } if _AttentionProviderRegistry._checks_enabled: removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name]) if removed_kwargs: log_freq = 512 msg = ( f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}. This " f"message will be logged every {log_freq} calls." 
) logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", msg, log_freq) for check in _AttentionProviderRegistry._constraints.get(provider_name): check(**kwargs) kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]} if _AttentionProviderRegistry.context_parallel_enabled(): _set_context_parallel_options(**kwargs) return provider_fn(**kwargs) # ===== Helper functions ===== # @torch.compiler.assume_constant_result def _set_context_parallel_options(is_causal: bool, **kwargs): _cp_options.enable_load_balance = is_causal _cp_options.convert_to_f32 = _AttentionProviderRegistry._convert_to_fp32 set_rotate_method(_AttentionProviderRegistry._rotate_method) def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: if attn_mask is not None: raise ValueError("Attention mask must be None for this provider.") def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: if attn_mask is not None and is_causal: raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: if query.device != key.device or query.device != value.device: raise ValueError("Query, key, and value must be on the same device.") if query.dtype != key.dtype or query.dtype != value.dtype: raise ValueError("Query, key, and value must have the same dtype.") def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: _check_device(query, key, value) if query.device.type != "cuda": raise ValueError("Query, key, and value must be on a CUDA device.") def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: _check_device_cuda(query, key, value) if torch.cuda.get_device_capability(query.device) < (major, minor): raise ValueError( f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." 
            )

    return check_device_cuda


def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    if query.dtype != key.dtype:
        raise ValueError("Query and key must have the same dtype.")
    if query.dtype != value.dtype:
        raise ValueError("Query and value must have the same dtype.")


def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    _check_qkv_dtype_match(query, key, value)
    if query.dtype not in (torch.bfloat16, torch.float16):
        raise ValueError("Query, key, and value must be either bfloat16 or float16.")


def _check_shape(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> None:
    if query.shape[-1] != key.shape[-1]:
        raise ValueError("Query and key must have the same last dimension.")
    if query.shape[-2] != value.shape[-2]:
        raise ValueError("Query and value must have the same second to last dimension.")
    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
        raise ValueError("Attention mask must match the key's second to last dimension.")


def _prepare_for_flash_attn_or_sage_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    if attn_mask is None:
        seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
    else:
        seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
    """
    Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in
    FlashAttention/Sage varlen.

    Supports 1D to 4D shapes and common broadcasting patterns.
    """
    if attn_mask.dtype != torch.bool:
        raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")

    if attn_mask.ndim == 1:
        # [seq_len_k] -> broadcast across batch
        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 2:
        # [batch_size, seq_len_k]. Maybe broadcast across batch
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
            )
        attn_mask = attn_mask.expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 3:
        # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
            )
        attn_mask = attn_mask.any(dim=1)
        attn_mask = attn_mask.expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 4:
        # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
) attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] else: raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") if attn_mask.shape != (batch_size, seq_len_k): raise ValueError( f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" ) return attn_mask def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx # ===== Attention provider implementations ===== # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 class _flash_attn_flash_attention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[int, int] = (-1, -1), softcap: float = 0.0, alibi_slopes: Optional[torch.Tensor] = None, deterministic: bool = False, return_softmax: bool = False, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic out, lse, q, k, v, out_padded, S_dmask, rng_state = _finetrainers_flash_attn_forward( query=q, key=k, value=v, dropout_p=dropout_p, scale=softmax_scale, is_causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax, ) ctx.save_for_backward(q, k, v, out_padded, lse, rng_state) return (out, lse) if return_softmax else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): q, k, v, out, lse, rng_state = ctx.saved_tensors grad_query, grad_key, grad_value = _finetrainers_flash_attn_backward( grad_out=grad_out, query=q, key=k, value=v, out=out, logsumexp=lse, dropout_p=ctx.dropout_p, scale=ctx.softmax_scale, is_causal=ctx.causal, window_size=ctx.window_size, softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, rng_state=rng_state, ) return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 class _native_ring_flash_attn_flash_attention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[int, int] = (-1, -1), softcap: float = 0.0, alibi_slopes: Optional[torch.Tensor] = None, deterministic: bool = False, return_softmax: bool = False, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) # For ring flash attention using the flash-attn repo, we want the LSE but flash-attn only supports it if dropout_p > 0 dropout_p = dropout_p if dropout_p > 0 else 1e-30 ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic out, lse, q, k, v, out_padded, S_dmask, rng_state = _templated_ring_attention( mesh=_AttentionProviderRegistry._mesh, seq_dim=2, op=_finetrainers_flash_attn_forward, 
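            # `op` is the finetrainers wrapper (rather than flash-attn's public API) so that the log-sum-exp is
            # returned at index 1 of the output tuple, where PyTorch's templated ring attention expects it when
            # combining per-rank partial results; the wrapper also handles the BNSD <-> BSND permutes.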
query=q, key=k, value=v, dropout_p=dropout_p, scale=softmax_scale, is_causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True, ) ctx.save_for_backward(q, k, v, out_padded, lse, rng_state) return (out, lse) if return_softmax else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): q, k, v, out, lse, rng_state = ctx.saved_tensors lse = lse.permute(0, 2, 1).contiguous() # [B, N, S] -> [B, S, N] grad_query, grad_key, grad_value = _templated_ring_attention_backward( mesh=_AttentionProviderRegistry._mesh, # This needs to be 1 because q, k, v, out_padded returned from forward are BSND instead of BNSD # The grad_out permutation is handled in _finetrainers_flash_attn_backward, and the outputs from that are expected to have # shape BSND instead of BNSD (requirement of _templated_ring_attention_backward), so we need to set seq_dim=1 and permute the # returned outputs seq_dim=1, op=functools.partial(_finetrainers_flash_attn_backward, _permute_outputs=False), grad_out=grad_out, grad_out_name="grad_out", query=q, key=k, value=v, out=out, logsumexp=lse, dropout_p=ctx.dropout_p, scale=ctx.softmax_scale, is_causal=ctx.causal, window_size=ctx.window_size, softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, rng_state=rng_state, ) grad_query, grad_key, grad_value = ( x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value) ) return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None @_AttentionProviderRegistry.register( AttentionProvider.FLASH, constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_cp=True, ) def flash_attn_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dropout_p: float = 0.0, scale: Optional[float] = None, is_causal: bool = False, window_size: Tuple[int, int] = (-1, -1), softcap: float = 0.0, alibi_slopes: Optional[torch.Tensor] = None, deterministic: bool = False, return_lse: bool = False, ) -> torch.Tensor: dispatch_fn = ( _native_ring_flash_attn_flash_attention if _AttentionProviderRegistry.context_parallel_enabled() else _flash_attn_flash_attention ) return dispatch_fn.apply( query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, return_lse ) @_AttentionProviderRegistry.register( AttentionProvider.FLASH_VARLEN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_cp=False, ) def _flash_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, dropout_p: float = 0.0, scale: Optional[float] = None, is_causal: bool = False, window_size: Tuple[int, int] = (-1, -1), softcap: float = 0.0, alibi_slopes: Optional[torch.Tensor] = None, deterministic: bool = False, return_attn_probs: bool = False, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, _, seq_len_q, _ = query.shape _, _, seq_len_kv, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, 
attn_mask=attn_mask, device=query.device ) ) else: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) key_valid, value_valid = [], [] for b in range(batch_size): valid_len = seqlens_k[b] key_valid.append(key[b, :valid_len]) value_valid.append(value[b, :valid_len]) query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) if _AttentionProviderRegistry.context_parallel_enabled(): return_attn_probs = True out = flash_attn_varlen_func( q=query_packed, k=key_packed, v=value_packed, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=return_attn_probs, ) rest = None if return_attn_probs: out, *rest = out out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous() if return_attn_probs: return out, *rest[:1] return out @_AttentionProviderRegistry.register( AttentionProvider.FLEX, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], supports_cp=False, ) def _native_flex_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, kernel_options: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: # TODO: should we LRU cache the block mask creation? score_mod = None block_mask = None batch_size, num_heads, seq_len_q, _ = query.shape _, _, seq_len_kv, _ = key.shape if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): block_mask = attn_mask elif is_causal: block_mask = flex_attention.create_block_mask( _flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device ) elif torch.is_tensor(attn_mask): if attn_mask.ndim == 2: attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) if attn_mask.dtype == torch.bool: # TODO: this probably does not work but verify! 
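            # The dense boolean mask is captured by the `mask_mod` closure below and re-indexed per
            # (batch, head, q, kv) position when `create_block_mask` builds the BlockMask.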
            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                return attn_mask[batch_idx, head_idx, q_idx, kv_idx]

            block_mask = flex_attention.create_block_mask(
                mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
            )
        else:
            def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
    else:
        raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

    return flex_attention.flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        enable_gqa=enable_gqa,
        return_lse=return_lse,
        kernel_options=kernel_options,
    )


@_AttentionProviderRegistry.register(
    AttentionProvider.NATIVE,
    constraints=[_check_device, _check_shape],
    supports_cp=False,
)
def _native_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    return native_sdpa(
        query=query,
        key=key,
        value=value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )


class _native_cudnn_attention(torch.autograd.Function):
    # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
    # forward declaration:
    # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
    # backward declaration:
    # aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float?
scale=None) -> (Tensor, Tensor, Tensor) @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ): ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale ctx.attn_mask = attn_mask out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( torch.ops.aten._scaled_dot_product_cudnn_attention( query=query, key=key, value=value, attn_bias=attn_mask, compute_log_sumexp=True, dropout_p=dropout_p, is_causal=is_causal, return_debug_mask=False, scale=scale, ) ) ctx.max_q = max_q ctx.max_k = max_k ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward( grad_out=grad_out, query=query, key=key, value=value, out=out, logsumexp=lse, philox_seed=philox_seed, philox_offset=philox_offset, attn_bias=ctx.attn_mask, cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=ctx.max_q, max_k=ctx.max_k, dropout_p=ctx.dropout_p, is_causal=ctx.is_causal, scale=ctx.scale, ) return grad_query, grad_key, grad_value, None, None, None, None, None class _native_ring_native_cudnn_attention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ): _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale ctx.attn_mask = attn_mask out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( _templated_ring_attention( mesh=_AttentionProviderRegistry._mesh, seq_dim=2, op=torch.ops.aten._scaled_dot_product_cudnn_attention, query=query, key=key, value=value, attn_bias=attn_mask, compute_log_sumexp=True, dropout_p=dropout_p, is_causal=is_causal, return_debug_mask=False, scale=scale, ) ) ctx.max_q = max_q ctx.max_k = max_k ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors grad_query, grad_key, grad_value = _templated_ring_attention_backward( mesh=_AttentionProviderRegistry._mesh, seq_dim=2, op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward, grad_out=grad_out, grad_out_name="grad_out", query=query, key=key, value=value, out=out, logsumexp=lse, philox_seed=philox_seed, philox_offset=philox_offset, attn_bias=ctx.attn_mask, cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=ctx.max_q, max_k=ctx.max_k, dropout_p=ctx.dropout_p, is_causal=ctx.is_causal, scale=ctx.scale, ) return grad_query, grad_key, grad_value, None, None, None, None, None 
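
# Illustrative usage sketch (comments only, not executed). The tensor shapes and the `mesh` object below are
# assumptions for the example, not values defined in this module. `attention_dispatch` mirrors
# `torch.nn.functional.scaled_dot_product_attention`, and the provider is selected with the
# `attention_provider` context manager:
#
#     query = key = value = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
#
#     with attention_provider(AttentionProvider._NATIVE_CUDNN):
#         out = attention_dispatch(query, key, value)
#
#     # With a `DeviceMesh`, providers registered with `supports_cp=True` (FLASH, _NATIVE_CUDNN,
#     # _NATIVE_EFFICIENT, _NATIVE_FLASH) run PyTorch's templated ring attention across ranks:
#     with attention_provider(AttentionProvider.FLASH, mesh=mesh, convert_to_fp32=True, rotate_method="allgather"):
#         out = attention_dispatch(query, key, value)
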
@_AttentionProviderRegistry.register( AttentionProvider._NATIVE_CUDNN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_cp=True, ) def native_cudnn_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ) -> torch.Tensor: dispatch_fn = ( _native_ring_native_cudnn_attention if _AttentionProviderRegistry.context_parallel_enabled() else _native_cudnn_attention ) return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse) class _native_efficient_attention(torch.autograd.Function): # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946 # forward declaration: # aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) # backward declaration: # aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ): ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale ctx.attn_mask = attn_mask # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. out, lse, philox_seed, philox_offset = _finetrainers_scaled_dot_product_efficient_attention_forward( query=query, key=key, value=value, attn_bias=attn_mask, compute_log_sumexp=True, dropout_p=dropout_p, is_causal=is_causal, scale=scale, ) ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset) return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. 
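        # The forward wrapper trims the log-sum-exp to `seq_len_q`, so before calling the aten backward the
        # wrapper below re-pads it to the 32-element (`kAlignLSE`) layout the memory-efficient kernel expects.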
grad_query, grad_key, grad_value, grad_attn_bias = ( _finetrainers_scaled_dot_product_efficient_attention_backward( grad_out_=grad_out, query=query, key=key, value=value, attn_bias=ctx.attn_mask, out=out, logsumexp=lse, philox_seed=philox_seed, philox_offset=philox_offset, dropout_p=ctx.dropout_p, grad_input_mask=[True, True, True, False], is_causal=ctx.is_causal, scale=ctx.scale, ) ) return grad_query, grad_key, grad_value, None, None, None, None, None class _native_ring_native_efficient_attention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ): _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale ctx.attn_mask = attn_mask # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. out, lse, philox_seed, philox_offset = _templated_ring_attention( mesh=_AttentionProviderRegistry._mesh, seq_dim=2, op=_finetrainers_scaled_dot_product_efficient_attention_forward, query=query, key=key, value=value, attn_bias=attn_mask, compute_log_sumexp=True, dropout_p=dropout_p, is_causal=is_causal, scale=scale, ) ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset) return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. grad_query, grad_key, grad_value, grad_attn_bias = _templated_ring_attention_backward( mesh=_AttentionProviderRegistry._mesh, seq_dim=2, op=_finetrainers_scaled_dot_product_efficient_attention_backward, grad_out=grad_out, grad_out_name="grad_out_", query=query, key=key, value=value, attn_bias=ctx.attn_mask, out=out, logsumexp=lse, philox_seed=philox_seed, philox_offset=philox_offset, dropout_p=ctx.dropout_p, grad_input_mask=[True, True, True, False], is_causal=ctx.is_causal, scale=ctx.scale, ) return grad_query, grad_key, grad_value, None, None, None, None, None @_AttentionProviderRegistry.register( AttentionProvider._NATIVE_EFFICIENT, constraints=[_check_device, _check_shape], supports_cp=True, ) def native_efficient_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, ) -> torch.Tensor: dispatch_fn = ( _native_ring_native_efficient_attention if _AttentionProviderRegistry.context_parallel_enabled() else _native_efficient_attention ) return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale) class _native_flash_attention(torch.autograd.Function): # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14910 # forward declaration: # aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) # backward declaration: # aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ): ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( torch.ops.aten._scaled_dot_product_flash_attention( query=query, key=key, value=value, dropout_p=dropout_p, is_causal=is_causal, return_debug_mask=False, scale=scale, ) ) ctx.max_q = max_q ctx.max_k = max_k ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward( grad_out=grad_out, query=query, key=key, value=value, out=out, logsumexp=lse, cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=ctx.max_q, max_k=ctx.max_k, dropout_p=ctx.dropout_p, is_causal=ctx.is_causal, philox_seed=philox_seed, philox_offset=philox_offset, scale=ctx.scale, ) return grad_query, grad_key, grad_value, None, None, None, None class _native_ring_native_flash_attention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ): _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( _templated_ring_attention( mesh=_AttentionProviderRegistry._mesh, seq_dim=2, op=torch.ops.aten._scaled_dot_product_flash_attention, query=query, key=key, value=value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, ) ) ctx.max_q = max_q ctx.max_k = max_k ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args: torch.Tensor, ): _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors grad_query, grad_key, grad_value, *_ = _templated_ring_attention_backward( mesh=_AttentionProviderRegistry._mesh, seq_dim=2, op=torch.ops.aten._scaled_dot_product_flash_attention_backward, grad_out=grad_out, grad_out_name="grad_out", query=query, key=key, value=value, out=out, logsumexp=lse, dropout_p=ctx.dropout_p, is_causal=ctx.is_causal, scale=ctx.scale, 
cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=ctx.max_q, max_k=ctx.max_k, philox_seed=philox_seed, philox_offset=philox_offset, ) return grad_query, grad_key, grad_value, None, None, None, None @_AttentionProviderRegistry.register( AttentionProvider._NATIVE_FLASH, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_cp=True, ) def native_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ) -> torch.Tensor: dispatch_fn = ( _native_ring_native_flash_attention if _AttentionProviderRegistry.context_parallel_enabled() else _native_flash_attention ) return dispatch_fn.apply(query, key, value, dropout_p, is_causal, scale, return_lse) # class _native_math_attention(torch.autograd.Function): # # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14901 # # forward declaration: # # aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0., bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) # # backward declaration: # # does not exist # @staticmethod # def forward( # ctx: torch.autograd.function.FunctionCtx, # query: torch.Tensor, # key: torch.Tensor, # value: torch.Tensor, # attn_mask: Optional[torch.Tensor] = None, # dropout_p: float = 0.0, # is_causal: bool = False, # dropout_mask: Optional[torch.Tensor] = None, # scale: Optional[float] = None, # enable_gqa: bool = False, # return_scores: bool = False, # ): # ctx.dropout_p = dropout_p # ctx.is_causal = is_causal # ctx.scale = scale # ctx.enable_gqa = enable_gqa # print(f"query.shape: {query.shape}") # with torch.enable_grad(): # out, scores = torch.ops.aten._scaled_dot_product_attention_math( # query=query, # key=key, # value=value, # attn_mask=attn_mask, # dropout_p=dropout_p, # is_causal=is_causal, # dropout_mask=dropout_mask, # scale=scale, # enable_gqa=enable_gqa, # ) # ctx.save_for_backward(query, key, value, out) # return (out, scores) if return_scores else out # @staticmethod # def backward( # ctx: torch.autograd.function.FunctionCtx, # grad_out: torch.Tensor, # ): # raise NotImplementedError("Backward pass for native math attention is not implemented.") @_AttentionProviderRegistry.register( AttentionProvider._NATIVE_MATH, constraints=[_check_device, _check_shape], supports_cp=False, ) def native_math_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, ) -> torch.Tensor: with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): return native_sdpa( query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa, ) @_AttentionProviderRegistry.register( AttentionProvider.SAGE, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_cp=False, ) def _sage_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, ) -> torch.Tensor: if _AttentionProviderRegistry.context_parallel_enabled(): return_lse = True kwargs = { "q": query, "k": key, "v": value, "tensor_layout": "HND", "is_causal": is_causal, "sm_scale": 
scale, "return_lse": return_lse, } out = sageattn(**kwargs) rest = None if return_lse: out, *rest = out if return_lse: return out, *rest[:1] return out @_AttentionProviderRegistry.register( AttentionProvider.SAGE_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _sage_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, is_causal: bool = False, scale: Optional[float] = None, smooth_k: bool = True, attn_mask: Optional[torch.Tensor] = None, enable_gqa: bool = False, ) -> torch.Tensor: batch_size, _, seq_len_q, _ = query.shape _, _, seq_len_kv, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) if enable_gqa: # TODO pass if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) ) else: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) key_valid, value_valid = [], [] for b in range(batch_size): valid_len = seqlens_k[b] key_valid.append(key[b, :valid_len]) value_valid.append(value[b, :valid_len]) query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) out = sageattn_varlen( q=query_packed, k=key_packed, v=value_packed, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, is_causal=is_causal, sm_scale=scale, smooth_k=smooth_k, ) out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous() return out @_AttentionProviderRegistry.register( AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], supports_cp=False, ) def _sage_qk_int8_pv_fp8_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", smooth_k: bool = True, smooth_v: bool = False, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp8_cuda( q=query, k=key, v=value, tensor_layout="HND", is_causal=is_causal, qk_quant_gran=qk_quant_gran, sm_scale=scale, pv_accum_dtype=pv_accum_dtype, smooth_k=smooth_k, smooth_v=smooth_v, return_lse=return_lse, ) @_AttentionProviderRegistry.register( AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90, constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], supports_cp=False, ) def _sage_qk_int8_pv_fp8_cuda_sm90_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", smooth_k: bool = True, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp8_cuda_sm90( q=query, k=key, v=value, tensor_layout="HND", is_causal=is_causal, 
qk_quant_gran=qk_quant_gran, sm_scale=scale, pv_accum_dtype=pv_accum_dtype, smooth_k=smooth_k, return_lse=return_lse, ) @_AttentionProviderRegistry.register( AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], supports_cp=False, ) def _sage_qk_int8_pv_fp16_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", smooth_k: bool = True, smooth_v: bool = False, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp16_cuda( q=query, k=key, v=value, tensor_layout="HND", is_causal=is_causal, qk_quant_gran=qk_quant_gran, sm_scale=scale, pv_accum_dtype=pv_accum_dtype, smooth_k=smooth_k, smooth_v=smooth_v, return_lse=return_lse, ) @_AttentionProviderRegistry.register( AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], supports_cp=False, ) def _sage_qk_int8_pv_fp16_triton_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", smooth_k: bool = True, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp16_triton( q=query, k=key, v=value, tensor_layout="HND", quantization_backend=quantization_backend, is_causal=is_causal, sm_scale=scale, smooth_k=smooth_k, return_lse=return_lse, ) @_AttentionProviderRegistry.register( AttentionProvider.XFORMERS, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], ) def _xformers_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, ) -> torch.Tensor: batch_size, num_heads_q, seq_len_q, _ = query.shape _, num_heads_kv, seq_len_kv, _ = key.shape # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns if is_causal: attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers # query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) if enable_gqa: if num_heads_q % num_heads_kv != 0: raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") num_heads_per_group = num_heads_q // num_heads_kv query = query.unflatten(2, (num_heads_kv, -1)) key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) if enable_gqa: out = out.flatten(2, 3) out = out.permute(0, 2, 1, 3) # .contiguous() return out
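
# Minimal smoke test (an illustrative sketch, not part of the library API). It exercises the dispatcher through
# the always-available NATIVE provider and compares against calling PyTorch's SDPA directly; the shapes below
# are arbitrary example values.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    # [batch, heads, seq_len, head_dim] layout, matching what the providers in this module expect.
    q = torch.randn(1, 4, 32, 64, device=device, dtype=dtype)
    k = torch.randn(1, 4, 32, 64, device=device, dtype=dtype)
    v = torch.randn(1, 4, 32, 64, device=device, dtype=dtype)

    # Override whatever FINETRAINERS_ATTN_PROVIDER defaults to and dispatch through the NATIVE provider.
    with attention_provider(AttentionProvider.NATIVE):
        out = attention_dispatch(q, k, v, is_causal=True)

    # With no dropout, the dispatched result should match native SDPA exactly.
    ref = native_sdpa(q, k, v, is_causal=True)
    torch.testing.assert_close(out, ref)
    print(f"attention_dispatch(NATIVE) matches native SDPA, output shape: {tuple(out.shape)}")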