---
license: mit
---
# triton-kernels
triton-kernels is a set of kernels that enable fast mixture-of-experts (MoE) computation on different architectures. The kernels support several precisions (e.g., bf16, mxfp4).
The original code is at https://github.com/triton-lang/triton/tree/main/python/triton_kernels
The current version corresponds to upstream commit 7d0efaa7231661299284a603512fce4fa255e62c.
## Quickstart
```bash
uv run https://huggingface.co/kernels-community/triton_kernels/raw/main/readme_example.py
```
```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch",
# "triton",
# "numpy",
# "kernels",
# ]
# ///
import torch
from kernels import get_kernel
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Load triton_kernels module via kernels library
triton_kernels = get_kernel("kernels-community/triton_kernels")
# Access modules directly from the loaded kernel
swiglu = triton_kernels.swiglu
routing = triton_kernels.routing
# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# SwiGLU example
x = torch.randn(512, 1024, device=device, dtype=torch.bfloat16)
y = swiglu.swiglu_torch(x, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"SwiGLU: {x.shape} -> {y.shape}")
# Routing example
logits = torch.randn(128, 8, device=device, dtype=torch.float16)
routing_data, gather_idx, scatter_idx = routing.routing_torch(logits, n_expts_act=2)
print(f"Routing: {routing_data.expt_hist.sum()} tokens routed")
# Combined example: size the SwiGLU input by the total count in the routing histogram
n_tokens = routing_data.expt_hist.sum().item()
x_moe = torch.randn(n_tokens, 512, device=device, dtype=torch.bfloat16)
y_moe = swiglu.swiglu_torch(x_moe, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"MoE SwiGLU: {x_moe.shape} -> {y_moe.shape}")
```
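For readers who want to see what the two Quickstart steps compute conceptually, here is a minimal pure-Python sketch. It is not the library's code: `topk_route` and `swiglu_halves` are hypothetical helper names, and the sketch assumes standard top-k routing (softmax over the selected experts) and the textbook SwiGLU that splits the input into a gate half and a linear half. The actual `routing_torch` and `swiglu_torch` may differ in element layout, in how `alpha` and `limit` are applied, and in what they return.

```python
import math
import random


def topk_route(logits, k):
    """Hypothetical helper: pick the k highest-scoring experts per token
    and softmax over just those k scores."""
    n_experts = len(logits[0])
    expt_hist = [0] * n_experts   # how many tokens were assigned to each expert
    assignments = []              # per token: list of (expert index, weight)
    for row in logits:
        top = sorted(range(n_experts), key=lambda e: row[e], reverse=True)[:k]
        m = max(row[e] for e in top)  # subtract the max for numerical stability
        exps = [math.exp(row[e] - m) for e in top]
        total = sum(exps)
        assignments.append([(e, w / total) for e, w in zip(top, exps)])
        for e in top:
            expt_hist[e] += 1
    return assignments, expt_hist


def swiglu_halves(row):
    """Hypothetical helper: textbook SwiGLU on one row, silu(gate) * linear.
    The output is half as wide as the input."""
    half = len(row) // 2
    gate, linear = row[:half], row[half:]
    return [g / (1.0 + math.exp(-g)) * v for g, v in zip(gate, linear)]


random.seed(0)

# Same shapes as the Quickstart: 128 tokens, 8 experts, 2 active experts per token.
logits = [[random.gauss(0.0, 1.0) for _ in range(8)] for _ in range(128)]
assignments, expt_hist = topk_route(logits, k=2)
print(sum(expt_hist))  # 128 tokens * 2 experts each = 256

# One row of width 1024 comes out with width 512.
row = [random.gauss(0.0, 1.0) for _ in range(1024)]
print(len(swiglu_halves(row)))  # 512
```

Under these assumptions the histogram sums to the number of tokens times the number of active experts, which is why a tensor sized from `expt_hist.sum()` has more rows than the original token count.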