
The latest build, b58ed97, may contain an accuracy issue that is currently being addressed. Please use it with caution; corrected outputs will be available soon.

Flash Attention

Flash Attention is a fast, memory-efficient implementation of the attention mechanism, designed for large models and long sequences. This repository is a build of Flash Attention compatible with the Hugging Face kernels library.

The original code is available at https://github.com/Dao-AILab/flash-attention.
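
For reference, the kernel computes standard scaled dot-product attention, so its outputs can be sanity-checked against PyTorch's built-in torch.nn.functional.scaled_dot_product_attention. The helper below is a minimal, illustrative sketch (not part of this kernel's API); it assumes the [batch, seq_len, heads, head_dim] layout used in the example script, whereas PyTorch expects [batch, heads, seq_len, head_dim], hence the transposes.

import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, is_causal=False):
    # q, k, v: [batch, seq_len, heads, head_dim] -> move to PyTorch's
    # [batch, heads, seq_len, head_dim] layout, run SDPA, transpose back.
    qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(qt, kt, vt, is_causal=is_causal)
    return out.transpose(1, 2).contiguous()

# Usage (after running the example below); tolerances are illustrative for float16:
#   torch.testing.assert_close(out, sdpa_reference(q, k, v), atol=2e-3, rtol=2e-3)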

# /// script
# dependencies = ["numpy", "torch", "kernels"]
# ///
import torch
from kernels import get_kernel

# Setup
torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

# Show available functions
print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])

# 1. Standard attention
print("\n1. Standard attention:")
B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
out = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]
print(f"Output: {out.shape}")

# 2. Variable length sequences
print("\n2. Variable length sequences:")
q_var = torch.randn(10, H, D, device=device, dtype=torch.float16)  # total_q=10
k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16)  # total_k=12
# For 3 sequences with lengths [3,4,3] for q and [4,5,3] for k
cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)
cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
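# (Illustrative) cumulative sequence lengths are just a zero-prefixed cumsum
# of the per-sequence lengths, e.g. for the q lengths [3, 4, 3] above:
lens_q = torch.tensor([3, 4, 3], device=device)
assert torch.equal(
    torch.cat([torch.zeros(1, device=device, dtype=torch.int32),
               torch.cumsum(lens_q, dim=0).to(torch.int32)]),
    cu_q,
)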
out_var = flash_attn.mha_varlen_fwd(
    q=q_var,
    k=k_var,
    v=v_var,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=4,
    max_seqlen_k=5,
)[0]
print(f"Output: {out_var.shape}")

# 3. KV-cache for autoregressive generation
print("\n3. KV-cache:")
cache_len, new_len = 10, 2
kcache = vcache = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
q_new = k_new = v_new = torch.randn(
    B, new_len, H, D, device=device, dtype=torch.float16
)
seqlens = torch.full((B,), cache_len + new_len, device=device, dtype=torch.int32)
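# kcache/vcache hold the previously computed keys/values (cache_len positions);
# k_new/v_new are the keys/values for the new_len tokens being processed.
# seqlens_k is set here to cache_len + new_len, i.e. the total key length per
# batch element.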
out_kv = flash_attn.mha_fwd_kvcache(
    q=q_new,
    kcache=kcache,
    vcache=vcache,
    k=k_new,
    v=v_new,
    seqlens_k=seqlens,
    is_causal=True,
)[0]
print(f"Output: {out_kv.shape}")

Expected output

Fetching 3 files: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3/3 [00:00<00:00, 16384.00it/s]
Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']

1. Standard attention:
Output: torch.Size([2, 5, 4, 8])

2. Variable length sequences:
Output: torch.Size([10, 4, 8])

3. KV-cache:
Output: torch.Size([2, 2, 4, 8])