""" Copyright (c) 2024 by SageAttention team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn.functional as F from sageattention.triton.quant_per_block import per_block_int8 as per_block_int8_triton from sageattention.triton.quant_per_block_varlen import per_block_int8 as per_block_int8_varlen_triton from sageattention.triton.attn_qk_int8_per_block import forward as attn_false from sageattention.triton.attn_qk_int8_per_block_causal import forward as attn_true from sageattention.triton.attn_qk_int8_block_varlen import forward as attn_false_varlen from sageattention.triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen from sageattention.triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton try: from sageattention import _qattn_sm80 SM80_ENABLED = True except: SM80_ENABLED = False try: from sageattention import _qattn_sm89 SM89_ENABLED = True except: SM89_ENABLED = False try: from sageattention import _qattn_sm90 SM90_ENABLED = True except: SM90_ENABLED = False from sageattention.quant import per_block_int8 as per_block_int8_cuda from sageattention.quant import per_warp_int8 as per_warp_int8_cuda from sageattention.quant import sub_mean from sageattention.quant import per_channel_fp8 from typing import Any, List, Literal, Optional, Tuple, Union import warnings import os def is_sage2_supported(): device_count = torch.cuda.device_count() for i in range(device_count): major, minor = torch.cuda.get_device_capability(i) if major < 8: return False return True def get_cuda_arch_versions(): cuda_archs = [] for i in range(torch.cuda.device_count()): major, minor = torch.cuda.get_device_capability(i) cuda_archs.append(f"sm{major}{minor}") return cuda_archs def sageattn( qkv_list, tensor_layout: str = "HND", is_causal: bool = False, sm_scale: Optional[float] = None, return_lse: bool = False, **kwargs: Any, ): """ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. Parameters ---------- q : torch.Tensor The query tensor. Shape: - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. k : torch.Tensor The key tensor. Shape: - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. v : torch.Tensor The value tensor. Shape: - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. tensor_layout : str The tensor layout, either "HND" or "NHD". Default: "HND". is_causal : bool Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. return_lse : bool Whether to return the log sum of the exponentiated attention weights. 
        Used for cases like Ring Attention. Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    """
    arch = get_cuda_arch_versions()[qkv_list[0].device.index]
    if arch == "sm80":
        return sageattn_qk_int8_pv_fp16_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")
    elif arch == "sm86":
        return sageattn_qk_int8_pv_fp16_triton(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)
    elif arch == "sm89":
        return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm90":
        return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm120":
        # sm120 has an accurate FP32 accumulator for FP8 MMA, and the Triton kernel is currently not usable on sm120.
        return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v=True)
    else:
        raise ValueError(f"Unsupported CUDA architecture: {arch}")


@torch.compiler.disable
def sageattn_qk_int8_pv_fp16_triton(
    qkv_list,
    # q: torch.Tensor,
    # k: torch.Tensor,
    # v: torch.Tensor,
    tensor_layout: str = "HND",
    quantization_backend: str = "triton",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    return_lse: bool = False,
    **kwargs: Any,
) -> torch.Tensor:
    """
    SageAttention with per-block INT8 quantization for Q and K, FP16 PV with FP16 accumulation, implemented using Triton.
    The FP16 accumulator is added to a FP32 buffer immediately after each iteration.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".
    quantization_backend : str
        The quantization backend, either "triton" or "cuda". The "cuda" backend offers better performance due to kernel fusion.
    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.
    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
      (``torch.float32`` inputs are not accepted by this wrapper).
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    q, k, v = qkv_list
    qkv_list.clear()
    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must have dtype torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # FIXME(DefTruth): make sage attention compatible with distributed
    # environments, e.g. xDiT launched via torchrun. Without this workaround,
    # sage attention runs into an illegal memory access error after the first
    # inference step of multi-GPU distributed inference. This small workaround
    # also makes sage attention compatible with torch.compile in
    # non-fullgraph compile mode.
    torch.cuda.set_device(v.device)

    head_dim_og = q.size(-1)

    if head_dim_og < 64:
        q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
    elif head_dim_og > 64 and head_dim_og < 128:
        q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
    elif head_dim_og > 128:
        raise ValueError(f"Unsupported head_dim: {head_dim_og}")

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
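
    # How key smoothing works here, in brief: when ``smooth_k`` is enabled, the
    # per-head mean of K (``km``) is subtracted from K inside the quantization
    # kernel, which narrows the per-block INT8 quantization range of K. Softmax
    # is shift-invariant per row, so subtracting q·km from every logit of a row
    # leaves the attention output unchanged; only the logsumexp shifts, which is
    # why ``lse_correction`` (q·km, times ``sm_scale``) is added back at the end
    # when ``return_lse`` is requested.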
    seq_dim = 1 if tensor_layout == "NHD" else 2

    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            if tensor_layout == "NHD":
                lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
    else:
        km = None

    if dtype == torch.bfloat16 or dtype == torch.float32:
        v = v.to(torch.float16)

    if sm_scale is None:
        sm_scale = 1.0 / (head_dim_og ** 0.5)

    if quantization_backend == "triton":
        q_int8, q_scale, k_int8, k_scale = per_block_int8_triton(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
    elif quantization_backend == "cuda":
        q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
    else:
        raise ValueError(f"Unsupported quantization backend: {quantization_backend}")

    del q, k, km

    if is_causal:
        o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
    else:
        o, lse = attn_false(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)

    o = o[..., :head_dim_og]

    if return_lse:
        return o, (lse / 1.44269504 + lse_correction * sm_scale) if smooth_k else (lse / 1.44269504)
    else:
        return o


@torch.compiler.disable
def sageattn_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    **kwargs: Any,
) -> torch.Tensor:
    """
    SageAttention with per-block INT8 quantization for Q and K and FP16 PV, for variable-length
    (packed) batches, implemented using Triton.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
    k : torch.Tensor
        The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
    v : torch.Tensor
        The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
    cu_seqlens_q : torch.Tensor
        The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
    cu_seqlens_k : torch.Tensor
        The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
    max_seqlen_q : int
        The maximum sequence length for the query tensor in the batch.
    max_seqlen_k : int
        The maximum sequence length for the key and value tensors in the batch.
    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence.
        Default: False.
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.

    Returns
    -------
    torch.Tensor
        The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
      (``torch.float32`` inputs are not accepted by this wrapper).
    - The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
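
    Example
    -------
    A minimal illustrative call; the shapes are assumptions made for this example
    (two packed sequences of lengths 3 and 5, 8 heads, head_dim 64), not requirements
    beyond those documented above.

    >>> cu = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    >>> q = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
    >>> k = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
    >>> v = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
    >>> o = sageattn_varlen(q, k, v, cu, cu, max_seqlen_q=5, max_seqlen_k=5)
    >>> o.shape
    torch.Size([8, 8, 64])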
""" dtype = q.dtype assert q.is_cuda, "Input tensors must be on cuda." assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." # FIXME(DefTruth): make sage attention work compatible with distributed # env, for example, xDiT which launch by torchrun. Without this workaround, # sage attention will run into illegal memory access error after first # inference step in distributed env for multi gpus inference. This small # workaround also make sage attention work compatible with torch.compile # through non-fullgraph compile mode. torch.cuda.set_device(v.device) head_dim_og = q.size(-1) if head_dim_og < 64: q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) elif head_dim_og > 64 and head_dim_og < 128: q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) elif head_dim_og > 128: raise ValueError(f"Unsupported head_dim: {head_dim_og}") assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous." assert cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous(), "cu_seqlens_q and cu_seqlens_k must be contiguous." if dtype == torch.bfloat16 or dtype == torch.float32: v = v.to(torch.float16) if smooth_k: km = k.mean(dim=0, keepdim=True) # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel. k = k - km if sm_scale is None: sm_scale = 1.0 / (head_dim_og ** 0.5) q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale = per_block_int8_varlen_triton(q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale) if is_causal: o = attn_true_varlen(q_int8, k_int8, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, output_dtype=dtype) else: o = attn_false_varlen(q_int8, k_int8, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, output_dtype=dtype) o = o[..., :head_dim_og] return o @torch.compiler.disable def sageattn_qk_int8_pv_fp16_cuda( qkv_list, # q: torch.Tensor, # k: torch.Tensor, # v: torch.Tensor, tensor_layout: str = "HND", is_causal: bool = False, qk_quant_gran: str = "per_thread", sm_scale: Optional[float] = None, pv_accum_dtype: str = "fp32", smooth_k: bool = True, smooth_v: bool = False, return_lse: bool = False, **kwargs: Any, ) -> torch.Tensor: """ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. Parameters ---------- q : torch.Tensor The query tensor. Shape: - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. k : torch.Tensor The key tensor. Shape: - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. v : torch.Tensor The value tensor. Shape: - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. 
    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".
    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.
    qk_quant_gran : str
        The granularity of quantization for Q and K, either "per_warp" or "per_thread".
        Default: "per_thread".
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
    pv_accum_dtype : str
        The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
        - "fp16": PV accumulation is done entirely in FP16. This is the fastest option but may lead to numerical instability.
          The `smooth_v` option will increase the accuracy in cases where the value tensor has a large bias (as in CogVideoX-2b).
        - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
        - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
        Default: "fp32".
    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.
    smooth_v : bool
        Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
        smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
        Default: False.
    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    q, k, v = qkv_list
    qkv_list.clear()
    dtype = q.dtype
    assert SM80_ENABLED, "SM80 kernel is not available. Make sure you have a GPU with compute capability 8.0 or higher."
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must have dtype torch.float16 or torch.bfloat16"
    assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # FIXME(DefTruth): make sage attention compatible with distributed
    # environments, e.g. xDiT launched via torchrun. Without this workaround,
    # sage attention runs into an illegal memory access error after the first
    # inference step of multi-GPU distributed inference. This small workaround
    # also makes sage attention compatible with torch.compile in
    # non-fullgraph compile mode.
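
    # The compiled extension expects plain integer flags rather than strings; the
    # mapping used just below is:
    #   tensor_layout: "NHD" -> 0, "HND" -> 1
    #   qk_quant_gran: "per_warp" -> 2, "per_thread" -> 3
    #   is_causal / return_lse: False -> 0, True -> 1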
    torch.cuda.set_device(v.device)

    _tensor_layout = 0 if tensor_layout == "NHD" else 1
    _is_causal = 1 if is_causal else 0
    _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
    _return_lse = 1 if return_lse else 0

    head_dim_og = q.size(-1)

    if head_dim_og < 64:
        q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
    elif head_dim_og > 64 and head_dim_og < 128:
        q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
    elif head_dim_og > 128:
        raise ValueError(f"Unsupported head_dim: {head_dim_og}")

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    if sm_scale is None:
        sm_scale = head_dim_og**-0.5

    seq_dim = 1 if _tensor_layout == 0 else 2

    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            if tensor_layout == "NHD":
                lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
    else:
        km = None

    if qk_quant_gran == "per_warp":
        q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), BLKK=64)
    elif qk_quant_gran == "per_thread":
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), BLKK=64, WARPK=64)

    q_size = q.size()
    q_device = q.device
    del q, k, km

    o = torch.empty(q_size, dtype=dtype, device=q_device)

    if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
        warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
        smooth_v = False

    if pv_accum_dtype == 'fp32':
        v = v.to(torch.float16)
        lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    elif pv_accum_dtype == "fp16":
        if smooth_v:
            smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
            del v
            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
        else:
            v = v.to(torch.float16)
            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    elif pv_accum_dtype == "fp16+fp32":
        v = v.to(torch.float16)
        lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    else:
        raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")

    o = o[..., :head_dim_og]

    if return_lse:
        return o, (lse / 1.44269504 + lse_correction * sm_scale) if smooth_k else (lse / 1.44269504)
    else:
        return o


@torch.compiler.disable
def sageattn_qk_int8_pv_fp8_cuda(
    qkv_list,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    qk_quant_gran: str = "per_thread",
    sm_scale: Optional[float] = None,
    pv_accum_dtype: str = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
    **kwargs: Any,
) -> torch.Tensor:
    """
    SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
    The q, k and v tensors are passed packed as ``qkv_list = [q, k, v]``; the list is cleared inside the call.
    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".
    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.
    qk_quant_gran : str
        The granularity of quantization for Q and K, either "per_warp" or "per_thread".
        Default: "per_thread".
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
    pv_accum_dtype : str
        The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
        - "fp32": PV accumulation is done entirely in FP32. However, due to a hardware limitation, the FP32 accumulator of the FP8 MMA effectively has only 22 valid bits.
        - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but is added to a separate FP32 buffer every few iterations. This offers a balance between speed and accuracy.
        Default: "fp32+fp32".
    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.
    smooth_v : bool
        Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
        smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
        Default: False.
    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    q, k, v = qkv_list
    qkv_list.clear()
    dtype = q.dtype
    assert SM89_ENABLED, "SM89 kernel is not available. Make sure you have a GPU with compute capability 8.9."
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must have dtype torch.float16 or torch.bfloat16"
    assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # FIXME(DefTruth): make sage attention compatible with distributed
    # environments, e.g. xDiT launched via torchrun. Without this workaround,
    # sage attention runs into an illegal memory access error after the first
    # inference step of multi-GPU distributed inference. This small workaround
    # also makes sage attention compatible with torch.compile in
    # non-fullgraph compile mode.
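
    # Note on the FP8 PV path: V is quantized per channel to FP8 by ``per_channel_fp8``
    # below, and (as described in the docstring) the FP32 accumulator of the FP8 MMA
    # effectively carries only 22 valid bits, which is why the default "fp32+fp32"
    # mode periodically adds partial results into a true FP32 buffer.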
    torch.cuda.set_device(v.device)

    _tensor_layout = 0 if tensor_layout == "NHD" else 1
    _is_causal = 1 if is_causal else 0
    _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
    _return_lse = 1 if return_lse else 0

    head_dim_og = q.size(-1)

    if head_dim_og < 64:
        q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
    elif head_dim_og > 64 and head_dim_og < 128:
        q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
    elif head_dim_og > 128:
        raise ValueError(f"Unsupported head_dim: {head_dim_og}")

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    if sm_scale is None:
        sm_scale = head_dim_og**-0.5

    seq_dim = 1 if _tensor_layout == 0 else 2

    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            if tensor_layout == "NHD":
                lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
    else:
        km = None

    if qk_quant_gran == "per_warp":
        q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64)
    elif qk_quant_gran == "per_thread":
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)

    q_size = q.size()
    q_device = q.device
    del q, k, km

    if pv_accum_dtype == 'fp32+fp32' and smooth_v:
        warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
        smooth_v = False

    v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v)
    del v

    o = torch.empty(q_size, dtype=dtype, device=q_device)

    if pv_accum_dtype == "fp32":
        if smooth_v:
            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
        else:
            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    elif pv_accum_dtype == "fp32+fp32":
        lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)

    o = o[..., :head_dim_og]

    if return_lse:
        return o, (lse / 1.44269504 + lse_correction * sm_scale) if smooth_k else (lse / 1.44269504)
    else:
        return o


@torch.compiler.disable
def sageattn_qk_int8_pv_fp8_window_cuda(
    qkv_list,
    # q: torch.Tensor,
    # k: torch.Tensor,
    # v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    qk_quant_gran: str = "per_thread",
    sm_scale: Optional[float] = None,
    pv_accum_dtype: str = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
    window=-1,
    **kwargs: Any,
) -> torch.Tensor:
    """
    SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation and support for a
    local attention window, implemented using CUDA.
    The q, k and v tensors are passed packed as ``qkv_list = [q, k, v]``; the list is cleared inside the call.
    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".
    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.
    qk_quant_gran : str
        The granularity of quantization for Q and K, either "per_warp" or "per_thread".
        Default: "per_thread".
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
    pv_accum_dtype : str
        The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
        - "fp32": PV accumulation is done entirely in FP32. However, due to a hardware limitation, the FP32 accumulator of the FP8 MMA effectively has only 22 valid bits.
        - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but is added to a separate FP32 buffer every few iterations. This offers a balance between speed and accuracy.
        Default: "fp32+fp32".
    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.
    smooth_v : bool
        Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
        smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
        Default: False.
    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.
    window : int
        The local attention window size forwarded to the underlying CUDA kernels. Default: -1.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    q, k, v = qkv_list
    qkv_list.clear()
    dtype = q.dtype
    assert SM89_ENABLED, "SM89 kernel is not available. Make sure you have a GPU with compute capability 8.9."
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must have dtype torch.float16 or torch.bfloat16"
    assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # FIXME(DefTruth): make sage attention compatible with distributed
    # environments, e.g. xDiT launched via torchrun. Without this workaround,
    # sage attention runs into an illegal memory access error after the first
    # inference step of multi-GPU distributed inference. This small workaround
    # also makes sage attention compatible with torch.compile in
    # non-fullgraph compile mode.
    torch.cuda.set_device(v.device)

    _tensor_layout = 0 if tensor_layout == "NHD" else 1
    _is_causal = 1 if is_causal else 0
    _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
    _return_lse = 1 if return_lse else 0

    head_dim_og = q.size(-1)

    if head_dim_og < 64:
        q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
    elif head_dim_og > 64 and head_dim_og < 128:
        q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
    elif head_dim_og > 128:
        raise ValueError(f"Unsupported head_dim: {head_dim_og}")

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    if sm_scale is None:
        sm_scale = head_dim_og**-0.5

    seq_dim = 1 if _tensor_layout == 0 else 2

    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            if tensor_layout == "NHD":
                lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
    else:
        km = None

    if qk_quant_gran == "per_warp":
        q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64)
    elif qk_quant_gran == "per_thread":
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)

    q_size = q.size()
    q_device = q.device
    del q, k

    if pv_accum_dtype == 'fp32+fp32' and smooth_v:
        warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
        smooth_v = False

    v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v)
    del v

    o = torch.empty(q_size, dtype=dtype, device=q_device)

    if pv_accum_dtype == "fp32":
        if smooth_v:
            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse, window)
        else:
            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse, window)
    elif pv_accum_dtype == "fp32+fp32":
        lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse, window)

    o = o[..., :head_dim_og]

    if return_lse:
        return o, (lse / 1.44269504 + lse_correction * sm_scale) if smooth_k else (lse / 1.44269504)
    else:
        return o


@torch.compiler.disable
def sageattn_qk_int8_pv_fp8_cuda_sm90(
    qkv_list,
    # q: torch.Tensor,
    # k: torch.Tensor,
    # v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    qk_quant_gran: str = "per_thread",
    sm_scale: Optional[float] = None,
    pv_accum_dtype: str = "fp32+fp32",
    smooth_k: bool = True,
    return_lse: bool = False,
    **kwargs: Any,
) -> torch.Tensor:
    """
    SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA for SM90 GPUs.
    The q, k and v tensors are passed packed as ``qkv_list = [q, k, v]``; the list is cleared inside the call.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor.
        Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".
    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.
    qk_quant_gran : str
        The granularity of quantization for Q and K, either "per_warp" or "per_thread".
        Default: "per_thread".
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
    pv_accum_dtype : str
        The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
        - "fp32": PV accumulation is done entirely in FP32. However, due to a hardware limitation, the FP32 accumulator of the FP8 MMA effectively has only 22 valid bits.
          This mode is currently not implemented for SM90 and raises ``NotImplementedError``.
        - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but is added to a separate FP32 buffer every few iterations. This offers a balance between speed and accuracy.
        Default: "fp32+fp32".
    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.
    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    q, k, v = qkv_list
    qkv_list.clear()
    dtype = q.dtype
    assert SM90_ENABLED, "SM90 kernel is not available. Make sure you have a GPU with compute capability 9.0."
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must have dtype torch.float16 or torch.bfloat16"
    assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
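
    # SM90-specific details handled in the body below: quantization uses BLKQ=64 /
    # BLKK=128 block sizes, V is zero-padded along the sequence dimension to a
    # multiple of 128 before per-channel FP8 quantization, and only the
    # "fp32+fp32" PV accumulation mode is currently implemented.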
    torch.cuda.set_device(v.device)

    _tensor_layout = 0 if tensor_layout == "NHD" else 1
    _is_causal = 1 if is_causal else 0
    _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
    _return_lse = 1 if return_lse else 0

    head_dim_og = q.size(-1)

    if head_dim_og < 64:
        q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
    elif head_dim_og > 64 and head_dim_og < 128:
        q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
        k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
        v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
    elif head_dim_og > 128:
        raise ValueError(f"Unsupported head_dim: {head_dim_og}")

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    if sm_scale is None:
        sm_scale = head_dim_og**-0.5

    seq_dim = 1 if _tensor_layout == 0 else 2

    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            if tensor_layout == "NHD":
                lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
    else:
        km = None

    if qk_quant_gran == "per_warp":
        q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128)
    elif qk_quant_gran == "per_thread":
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)

    q_size = q.size()
    kv_len = k.size(seq_dim)
    q_device = q.device
    del q, k

    # pad v to a multiple of 128
    # TODO: modify the per_channel_fp8 kernel to handle this
    v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
    if v_pad_len > 0:
        if tensor_layout == "HND":
            v = torch.cat([v, torch.zeros(v.size(0), v.size(1), v_pad_len, v.size(3), dtype=v.dtype, device=v.device)], dim=2)
        else:
            v = torch.cat([v, torch.zeros(v.size(0), v_pad_len, v.size(2), v.size(3), dtype=v.dtype, device=v.device)], dim=1)

    v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
    del v

    o = torch.empty(q_size, dtype=dtype, device=q_device)

    if pv_accum_dtype == "fp32":
        raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
        # Unreachable until the non-buffered FP32 kernel is enabled for SM90:
        # lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    elif pv_accum_dtype == "fp32+fp32":
        lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)

    o = o[..., :head_dim_og]

    if return_lse:
        return o, (lse / 1.44269504 + lse_correction * sm_scale) if smooth_k else (lse / 1.44269504)
    else:
        return o
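

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch under assumptions, not part of the upstream API.
# It assumes a single CUDA GPU whose architecture is handled by ``sageattn``
# (sm80/sm86/sm89/sm90/sm120) and that the matching extension is compiled.
if __name__ == "__main__":
    if torch.cuda.is_available():
        # HND layout: [batch_size, num_heads, seq_len, head_dim]
        q = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
        k = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
        v = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
        out = sageattn([q, k, v], tensor_layout="HND", is_causal=True)
        # The dispatcher returns an output with the same shape as q.
        assert out.shape == (1, 8, 256, 64)
        print("sageattn ok:", tuple(out.shape), out.dtype)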